diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index a44bb61..21af863 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -7,9 +7,13 @@ "userEnvProbe": "loginShell", // Add the IDs of extensions you want installed when the container is created. - "extensions": [ - "pranayagarwal.vscode-hack" - ], + "customizations": { + "vscode": { + "extensions": [ + "pranayagarwal.vscode-hack" + ] + } + }, "mounts": [], diff --git a/.hhconfig b/.hhconfig index 6f6d160..ab337fa 100644 --- a/.hhconfig +++ b/.hhconfig @@ -1,5 +1,4 @@ hackfmt.line_width=120 +hackfmt.tabs=true allowed_decl_fixme_codes=2053,3012,4045,4047,4341 allowed_fixme_codes_strict=2011,2049,2050,2053,2083,3012,3084,4027,4038,4045,4047,4104,4105,4106,4107,4108,4110,4128,4135,4188,4223,4240,4323,4341,4390,4401 - - diff --git a/bin/setup-devcontainer b/bin/setup-devcontainer index 6e5efb9..dc10b4f 100755 --- a/bin/setup-devcontainer +++ b/bin/setup-devcontainer @@ -39,6 +39,6 @@ echo "" # Install dependencies echo "Installing dependencies..." 
-php composer.phar install +php composer.phar install --ignore-platform-reqs echo "" diff --git a/src/AsyncMysql/AsyncMysqlClient.php b/src/AsyncMysql/AsyncMysqlClient.php index 20e0d50..220f3bc 100644 --- a/src/AsyncMysql/AsyncMysqlClient.php +++ b/src/AsyncMysql/AsyncMysqlClient.php @@ -25,7 +25,7 @@ public static function setPoolsConnectionLimit(int $_limit): void {} int $_tcp_timeout_micros = 0, string $_sni_server_name = '', string $_server_cert_extension = '', - string $_server_cert_values = '', + string $_server_cert_values = '', ): Awaitable<\AsyncMysqlConnection> { return new AsyncMysqlConnection($host, $port, $dbname); } diff --git a/src/AsyncMysql/AsyncMysqlClientStats.php b/src/AsyncMysql/AsyncMysqlClientStats.php index 27495d5..315cadb 100644 --- a/src/AsyncMysql/AsyncMysqlClientStats.php +++ b/src/AsyncMysql/AsyncMysqlClientStats.php @@ -9,27 +9,27 @@ <<__MockClass>> final class AsyncMysqlClientStats extends \AsyncMysqlClientStats { - /* HH_IGNORE_ERROR[3012] I don't want to call parent::construct */ - public function __construct() {} + /* HH_IGNORE_ERROR[3012] I don't want to call parent::construct */ + public function __construct() {} - <<__Override>> - public function ioEventLoopMicrosAvg(): float { - return 0.0; - } - <<__Override>> - public function callbackDelayMicrosAvg(): float { - return 0.0; - } - <<__Override>> - public function ioThreadBusyMicrosAvg(): float { - return 0.0; - } - <<__Override>> - public function ioThreadIdleMicrosAvg(): float { - return 0.0; - } - <<__Override>> - public function notificationQueueSize(): int { - return 0; - } + <<__Override>> + public function ioEventLoopMicrosAvg(): float { + return 0.0; + } + <<__Override>> + public function callbackDelayMicrosAvg(): float { + return 0.0; + } + <<__Override>> + public function ioThreadBusyMicrosAvg(): float { + return 0.0; + } + <<__Override>> + public function ioThreadIdleMicrosAvg(): float { + return 0.0; + } + <<__Override>> + public function notificationQueueSize(): 
int { + return 0; + } } diff --git a/src/AsyncMysql/AsyncMysqlConnectResult.php b/src/AsyncMysql/AsyncMysqlConnectResult.php index fecd331..4ae0a27 100644 --- a/src/AsyncMysql/AsyncMysqlConnectResult.php +++ b/src/AsyncMysql/AsyncMysqlConnectResult.php @@ -4,35 +4,35 @@ <<__MockClass>> final class AsyncMysqlConnectResult extends \AsyncMysqlConnectResult { - private float $elapsed; - private float $start; + private float $elapsed; + private float $start; - /* HH_IGNORE_ERROR[3012] I don't want to call parent::construct */ - public function __construct(bool $from_pool) { - // pretend connections take longer if they don't come from the pool - if ($from_pool) { - $this->elapsed = .001; - } else { - $this->elapsed = .01; - } - $this->start = \microtime(true); - } + /* HH_IGNORE_ERROR[3012] I don't want to call parent::construct */ + public function __construct(bool $from_pool) { + // pretend connections take longer if they don't come from the pool + if ($from_pool) { + $this->elapsed = .001; + } else { + $this->elapsed = .01; + } + $this->start = \microtime(true); + } - <<__Override>> - public function elapsedMicros(): int { - return (int)($this->elapsed * 1000000); - } - <<__Override>> - public function startTime(): float { - return $this->start; - } - <<__Override>> - public function endTime(): float { - return $this->start + $this->elapsed; - } + <<__Override>> + public function elapsedMicros(): int { + return (int)($this->elapsed * 1000000); + } + <<__Override>> + public function startTime(): float { + return $this->start; + } + <<__Override>> + public function endTime(): float { + return $this->start + $this->elapsed; + } - <<__Override>> - public function clientStats(): \AsyncMysqlClientStats { - return new AsyncMysqlClientStats(); - } + <<__Override>> + public function clientStats(): \AsyncMysqlClientStats { + return new AsyncMysqlClientStats(); + } } diff --git a/src/AsyncMysql/AsyncMysqlConnection.php b/src/AsyncMysql/AsyncMysqlConnection.php index 
7855d63..96d937e 100644 --- a/src/AsyncMysql/AsyncMysqlConnection.php +++ b/src/AsyncMysql/AsyncMysqlConnection.php @@ -7,155 +7,155 @@ <<__MockClass>> final class AsyncMysqlConnection extends \AsyncMysqlConnection { - private bool $open = true; - private bool $reusable = true; - private AsyncMysqlConnectResult $result; - - /** - * Not part of the built-in AsyncMysqlConnection - */ - private Server $server; - private QueryStringifier $queryStringifier; - - public function getServer(): Server { - return $this->server; - } - - public function getDatabase(): string { - return $this->dbname; - } - - public function setDatabase(string $dbname): void { - $this->dbname = $dbname; - } - - /* HH_IGNORE_ERROR[3012] I don't want to call parent::construct */ - public function __construct(private string $host, private int $port, private string $dbname, ?QueryStringifier $query_stringifier = null) { - $this->server = Server::getOrCreate($host); - $this->result = new AsyncMysqlConnectResult(false); - $this->queryStringifier = $query_stringifier ?? QueryStringifier::createForTypesafeHack(); - } - - <<__Override>> - public async function query( - string $query, - int $_timeout_micros = -1, - dict $_query_attributes = dict[], - ): Awaitable { - Logger::log(Verbosity::QUERIES, "SQLFake [verbose]: $query"); - - $config = $this->server->config; - $strict_sql_before = QueryContext::$strictSQLMode; - if ($config['strict_sql_mode'] ?? false) { - QueryContext::$strictSQLMode = true; - } - - $strict_schema_before = QueryContext::$strictSchemaMode; - if ($config['strict_schema_mode'] ?? false) { - QueryContext::$strictSchemaMode = true; - } - - if (($config['inherit_schema_from'] ?? '') !== '') { - $this->dbname = $config['inherit_schema_from'] ?? 
''; - } - - try { - list($results, $rows_affected) = SQLCommandProcessor::execute($query, $this); - } catch (\Exception $e) { - // this makes debugging a failing unit test easier, show the actual query that failed parsing along with the parser error - QueryContext::$strictSQLMode = $strict_sql_before; - QueryContext::$strictSchemaMode = $strict_schema_before; - $msg = $e->getMessage(); - $type = \get_class($e); - Logger::log(Verbosity::QUIET, "SQL Fake $type: $msg in SQL query: $query"); - throw $e; - } - QueryContext::$strictSQLMode = $strict_sql_before; - QueryContext::$strictSchemaMode = $strict_schema_before; - Logger::logResult($this->getServer()->name, $results, $rows_affected); - return new AsyncMysqlQueryResult(vec($results), $rows_affected); - } - - <<__Override>> - public async function queryAsync(SQL\Query $query): Awaitable { - return await $this->query($this->queryStringifier->formatQuery($query)); - } - - <<__Override>> - public function queryf(\HH\FormatString<\HH\SQLFormatter> $query, mixed ...$args): Awaitable { - invariant($query is string, '\\HH\\FormatString<_> is opaque, but we need it to be a string here.'); - return $this->query($this->queryStringifier->formatString($query, vec($args))); - } - - <<__Override>> - public async function multiQuery( - Traversable $queries, - int $_timeout_micros = -1, - dict $_query_attributes = dict[], - ): Awaitable> { - $results = await Vec\map_async($queries, $query ==> $this->query($query)); - return Vector::fromItems($results); - } - - <<__Override>> - public function escapeString(string $data): string { - // not actually escaping obviously - return $data; - } - - <<__Override>> - public function close(): void { - $this->open = false; - } - - <<__Override>> - public function releaseConnection(): void {} - - <<__Override>> - public function isValid(): bool { - return $this->open; - } - - <<__Override>> - public function serverInfo(): string { - // Copied from 
https://docs.hhvm.com/hack/reference/class/AsyncMysqlConnection/serverInfo/ - return '5.6.24-fb-log-slackhq-sql-fake'; - } - - <<__Override>> - public function warningCount(): int { - return 0; - } - - <<__Override>> - public function host(): string { - return $this->host; - } - - <<__Override>> - public function port(): int { - return $this->port; - } - - <<__Override>> - public function setReusable(bool $reusable): void { - $this->reusable = $reusable; - } - - <<__Override>> - public function isReusable(): bool { - return $this->reusable; - } - - <<__Override>> - public function lastActivityTime(): float { - // A float representing the number of seconds ago since epoch that we had successful activity on the current connection. - // 50 ms ago seems like a reasonable answer. - return \microtime(true) - 0.05; - } - - <<__Override>> - public function connectResult(): \AsyncMysqlConnectResult { - return $this->result; - } + private bool $open = true; + private bool $reusable = true; + private AsyncMysqlConnectResult $result; + + /** + * Not part of the built-in AsyncMysqlConnection + */ + private Server $server; + private QueryStringifier $queryStringifier; + + public function getServer(): Server { + return $this->server; + } + + public function getDatabase(): string { + return $this->dbname; + } + + public function setDatabase(string $dbname): void { + $this->dbname = $dbname; + } + + /* HH_IGNORE_ERROR[3012] I don't want to call parent::construct */ + public function __construct(private string $host, private int $port, private string $dbname, ?QueryStringifier $query_stringifier = null) { + $this->server = Server::getOrCreate($host); + $this->result = new AsyncMysqlConnectResult(false); + $this->queryStringifier = $query_stringifier ?? 
QueryStringifier::createForTypesafeHack(); + } + + <<__Override>> + public async function query( + string $query, + int $_timeout_micros = -1, + dict $_query_attributes = dict[], + ): Awaitable { + Logger::log(Verbosity::QUERIES, "SQLFake [verbose]: $query"); + + $config = $this->server->config; + $strict_sql_before = QueryContext::$strictSQLMode; + if ($config['strict_sql_mode'] ?? false) { + QueryContext::$strictSQLMode = true; + } + + $strict_schema_before = QueryContext::$strictSchemaMode; + if ($config['strict_schema_mode'] ?? false) { + QueryContext::$strictSchemaMode = true; + } + + if (($config['inherit_schema_from'] ?? '') !== '') { + $this->dbname = $config['inherit_schema_from'] ?? ''; + } + + try { + list($results, $rows_affected) = SQLCommandProcessor::execute($query, $this); + } catch (\Exception $e) { + // this makes debugging a failing unit test easier, show the actual query that failed parsing along with the parser error + QueryContext::$strictSQLMode = $strict_sql_before; + QueryContext::$strictSchemaMode = $strict_schema_before; + $msg = $e->getMessage(); + $type = \get_class($e); + Logger::log(Verbosity::QUIET, "SQL Fake $type: $msg in SQL query: $query"); + throw $e; + } + QueryContext::$strictSQLMode = $strict_sql_before; + QueryContext::$strictSchemaMode = $strict_schema_before; + Logger::logResult($this->getServer()->name, $results, $rows_affected); + return new AsyncMysqlQueryResult(vec($results), $rows_affected); + } + + <<__Override>> + public async function queryAsync(SQL\Query $query): Awaitable { + return await $this->query($this->queryStringifier->formatQuery($query)); + } + + <<__Override>> + public function queryf(\HH\FormatString<\HH\SQLFormatter> $query, mixed ...$args): Awaitable { + invariant($query is string, '\\HH\\FormatString<_> is opaque, but we need it to be a string here.'); + return $this->query($this->queryStringifier->formatString($query, vec($args))); + } + + <<__Override>> + public async function multiQuery( + 
Traversable $queries, + int $_timeout_micros = -1, + dict $_query_attributes = dict[], + ): Awaitable> { + $results = await Vec\map_async($queries, $query ==> $this->query($query)); + return Vector::fromItems($results); + } + + <<__Override>> + public function escapeString(string $data): string { + // not actually escaping obviously + return $data; + } + + <<__Override>> + public function close(): void { + $this->open = false; + } + + <<__Override>> + public function releaseConnection(): void {} + + <<__Override>> + public function isValid(): bool { + return $this->open; + } + + <<__Override>> + public function serverInfo(): string { + // Copied from https://docs.hhvm.com/hack/reference/class/AsyncMysqlConnection/serverInfo/ + return '5.6.24-fb-log-slackhq-sql-fake'; + } + + <<__Override>> + public function warningCount(): int { + return 0; + } + + <<__Override>> + public function host(): string { + return $this->host; + } + + <<__Override>> + public function port(): int { + return $this->port; + } + + <<__Override>> + public function setReusable(bool $reusable): void { + $this->reusable = $reusable; + } + + <<__Override>> + public function isReusable(): bool { + return $this->reusable; + } + + <<__Override>> + public function lastActivityTime(): float { + // A float representing the number of seconds ago since epoch that we had successful activity on the current connection. + // 50 ms ago seems like a reasonable answer. 
+ return \microtime(true) - 0.05; + } + + <<__Override>> + public function connectResult(): \AsyncMysqlConnectResult { + return $this->result; + } } diff --git a/src/AsyncMysql/AsyncMysqlConnectionPool.php b/src/AsyncMysql/AsyncMysqlConnectionPool.php index 9d94414..4913e03 100644 --- a/src/AsyncMysql/AsyncMysqlConnectionPool.php +++ b/src/AsyncMysql/AsyncMysqlConnectionPool.php @@ -11,70 +11,70 @@ */ final class AsyncMysqlConnectionPool extends \AsyncMysqlConnectionPool { - private int $createdPoolConnections = 0; - private int $destroyedPoolConnections = 0; - private int $connectionsRequest = 0; - private int $poolHits = 0; - private int $poolMisses = 0; - private static dict $pool = dict[]; + private int $createdPoolConnections = 0; + private int $destroyedPoolConnections = 0; + private int $connectionsRequest = 0; + private int $poolHits = 0; + private int $poolMisses = 0; + private static dict $pool = dict[]; - public function reset(): void { - static::$pool = dict[]; - } + public function reset(): void { + static::$pool = dict[]; + } - <<__Override>> - /* HH_FIXME[4341] temp fix signature changed */ - public async function connect( - string $host, - int $port, - string $dbname, - string $_user, - string $_password, - int $_timeout_micros = -1, - string $_caller = '', - ?\MySSLContextProvider $_ssl_context = null, - int $_tcp_timeout_micros = 0, + <<__Override>> + /* HH_FIXME[4341] temp fix signature changed */ + public async function connect( + string $host, + int $port, + string $dbname, + string $_user, + string $_password, + int $_timeout_micros = -1, + string $_caller = '', + ?\MySSLContextProvider $_ssl_context = null, + int $_tcp_timeout_micros = 0, string $_sni_server_name = '', string $_server_cert_extensions = '', string $_server_cert_values = '', - ): Awaitable { - $this->connectionsRequest++; - if (C\contains_key(static::$pool, $host)) { - $this->poolHits++; - $conn = static::$pool[$host]; - $conn->setDatabase($dbname); - return $conn; - } + ): 
Awaitable { + $this->connectionsRequest++; + if (C\contains_key(static::$pool, $host)) { + $this->poolHits++; + $conn = static::$pool[$host]; + $conn->setDatabase($dbname); + return $conn; + } - $this->poolMisses++; - $this->createdPoolConnections++; - $conn = new AsyncMysqlConnection($host, $port, $dbname); - static::$pool[$host] = $conn; - return $conn; - } + $this->poolMisses++; + $this->createdPoolConnections++; + $conn = new AsyncMysqlConnection($host, $port, $dbname); + static::$pool[$host] = $conn; + return $conn; + } - <<__Override>> - public function connectWithOpts( - string $host, - int $port, - string $dbname, - string $user, - string $password, - \AsyncMysqlConnectionOptions $_conn_opts, - string $caller = '', - ): Awaitable<\AsyncMysqlConnection> { - // currently, options are ignored in SQLFake - return $this->connect($host, $port, $dbname, $user, $password, -1, $caller); - } + <<__Override>> + public function connectWithOpts( + string $host, + int $port, + string $dbname, + string $user, + string $password, + \AsyncMysqlConnectionOptions $_conn_opts, + string $caller = '', + ): Awaitable<\AsyncMysqlConnection> { + // currently, options are ignored in SQLFake + return $this->connect($host, $port, $dbname, $user, $password, -1, $caller); + } - <<__Override>> - public function getPoolStats(): darray { - return darray[ - 'created_pool_connections' => $this->createdPoolConnections, - 'destroyed_pool_connections' => $this->destroyedPoolConnections, - 'connections_request' => $this->connectionsRequest, - 'pool_hits' => $this->poolHits, - 'pool_misses' => $this->poolMisses, - ]; - } + <<__Override>> + public function getPoolStats(): darray { + return darray[ + 'created_pool_connections' => $this->createdPoolConnections, + 'destroyed_pool_connections' => $this->destroyedPoolConnections, + 'connections_request' => $this->connectionsRequest, + 'pool_hits' => $this->poolHits, + 'pool_misses' => $this->poolMisses, + ]; + } } diff --git 
a/src/AsyncMysql/AsyncMysqlQueryResult.php b/src/AsyncMysql/AsyncMysqlQueryResult.php index 0776bed..f78f53c 100644 --- a/src/AsyncMysql/AsyncMysqlQueryResult.php +++ b/src/AsyncMysql/AsyncMysqlQueryResult.php @@ -9,102 +9,102 @@ <<__MockClass>> final class AsyncMysqlQueryResult extends \AsyncMysqlQueryResult { - /* HH_IGNORE_ERROR[3012] I don't want to call parent::construct */ - public function __construct(private dataset $rows, private int $rows_affected = 0, private int $last_insert_id = 0) {} - - public function rows(): dataset { - return $this->rows; - } - - <<__Override>> - public function numRowsAffected(): int { - return $this->rows_affected; - } - - <<__Override>> - public function lastInsertId(): int { - return $this->last_insert_id; - } - - <<__Override>> - public function numRows(): int { - return C\count($this->rows); - } - - <<__Override>> - public function mapRows(): Vector> { - $out = Vector {}; - foreach ($this->rows as $row) { - $map = Map {}; - foreach ($row as $column => $value) { - // in the untyped version, all columns are `?string` - $map->set($column, $value is nonnull ? (string)$value : null); - } - $out->add($map); - } - return $out; - } - - <<__Override>> - public function dictRowsTyped(): vec> { - return vec($this->rows); - } - - <<__Override>> - public function mapRowsTyped(): Vector> { - $out = Vector {}; - foreach ($this->rows as $row) { - $out->add(new Map($row)); - } - return $out; - } - - <<__Override>> - public function vectorRows(): Vector> { - $out = Vector {}; - foreach ($this->rows as $row) { - $v = Vector {}; - foreach ($row as $value) { - // in the untyped version, all columns are `?string` - $v->add($value is nonnull ? 
(string)$value : null); - } - $out->add($v); - } - return $out; - } - - <<__Override>> - public function vectorRowsTyped(): Vector> { - $out = Vector {}; - foreach ($this->rows as $row) { - $v = Vector {}; - foreach ($row as $value) { - $v->add($value); - } - $out->add($v); - } - return $out; - } - - <<__Override>> - public function rowBlocks(): mixed { - throw new SQLFakeNotImplementedException('row blocks not implemented'); - } - - <<__Override>> - public function noIndexUsed(): bool { - // TODO: it would be really interesting to actually try to determine if a query could use an index - // and set this value so that this could be instrumented in tests - return true; - } - - <<__Override>> - public function recvGtid(): string { - return 'stubbed'; - } - - <<__Override>> - public function elapsedMicros(): int { - return 100; - } + /* HH_IGNORE_ERROR[3012] I don't want to call parent::construct */ + public function __construct(private dataset $rows, private int $rows_affected = 0, private int $last_insert_id = 0) {} + + public function rows(): dataset { + return $this->rows; + } + + <<__Override>> + public function numRowsAffected(): int { + return $this->rows_affected; + } + + <<__Override>> + public function lastInsertId(): int { + return $this->last_insert_id; + } + + <<__Override>> + public function numRows(): int { + return C\count($this->rows); + } + + <<__Override>> + public function mapRows(): Vector> { + $out = Vector {}; + foreach ($this->rows as $row) { + $map = Map {}; + foreach ($row as $column => $value) { + // in the untyped version, all columns are `?string` + $map->set($column, $value is nonnull ? 
(string)$value : null); + } + $out->add($map); + } + return $out; + } + + <<__Override>> + public function dictRowsTyped(): vec> { + return vec($this->rows); + } + + <<__Override>> + public function mapRowsTyped(): Vector> { + $out = Vector {}; + foreach ($this->rows as $row) { + $out->add(new Map($row)); + } + return $out; + } + + <<__Override>> + public function vectorRows(): Vector> { + $out = Vector {}; + foreach ($this->rows as $row) { + $v = Vector {}; + foreach ($row as $value) { + // in the untyped version, all columns are `?string` + $v->add($value is nonnull ? (string)$value : null); + } + $out->add($v); + } + return $out; + } + + <<__Override>> + public function vectorRowsTyped(): Vector> { + $out = Vector {}; + foreach ($this->rows as $row) { + $v = Vector {}; + foreach ($row as $value) { + $v->add($value); + } + $out->add($v); + } + return $out; + } + + <<__Override>> + public function rowBlocks(): mixed { + throw new SQLFakeNotImplementedException('row blocks not implemented'); + } + + <<__Override>> + public function noIndexUsed(): bool { + // TODO: it would be really interesting to actually try to determine if a query could use an index + // and set this value so that this could be instrumented in tests + return true; + } + + <<__Override>> + public function recvGtid(): string { + return 'stubbed'; + } + + <<__Override>> + public function elapsedMicros(): int { + return 100; + } } diff --git a/src/BuildSchemaCLI.php b/src/BuildSchemaCLI.php index 2b50373..dc77193 100644 --- a/src/BuildSchemaCLI.php +++ b/src/BuildSchemaCLI.php @@ -77,9 +77,8 @@ protected function getSupportedOptions(): vec { HackBuilderValues::shapeWithPerKeyRendering( shape( 'name' => HackBuilderValues::export(), - 'indexes' => HackBuilderValues::vec( - HackBuilderValues::shapeWithUniformRendering(HackBuilderValues::export()), - ), + 'indexes' => + HackBuilderValues::vec(HackBuilderValues::shapeWithUniformRendering(HackBuilderValues::export())), 'fields' => 
HackBuilderValues::vec(HackBuilderValues::shapeWithPerKeyRendering( shape( 'name' => HackBuilderValues::export(), diff --git a/src/DataIntegrity.php b/src/DataIntegrity.php index 6ff56e0..6758f13 100644 --- a/src/DataIntegrity.php +++ b/src/DataIntegrity.php @@ -14,301 +14,316 @@ */ abstract final class DataIntegrity { - <<__Memoize>> - public static function namesForSchema(table_schema $schema): keyset { - return Keyset\map($schema['fields'], $field ==> $field['name']); - } + <<__Memoize>> + public static function namesForSchema(table_schema $schema): keyset { + return Keyset\map($schema['fields'], $field ==> $field['name']); + } - protected static function getDefaultValueForField( - string $field_type, - bool $nullable, - ?string $default, - string $field_name, - string $table_name, - ): mixed { + protected static function getDefaultValueForField( + string $field_type, + bool $nullable, + ?string $default, + string $field_name, + string $table_name, + ): mixed { - if ($default !== null) { - switch ($field_type) { - case 'int': - return Str\to_int($default); - break; - case 'double': - return (float)$default; - break; - default: - return $default; - break; - } - } else if ($nullable) { - return null; - } + if ($default !== null) { + switch ($field_type) { + case 'int': + return Str\to_int($default); + break; + case 'double': + return (float)$default; + break; + default: + return $default; + break; + } + } else if ($nullable) { + return null; + } - if (QueryContext::$strictSQLMode) { - // if we got this far the column has no default and isn't nullable, strict would throw - // but default MySQL mode would coerce to a valid value - throw new SQLFakeRuntimeException("Column '{$field_name}' on '{$table_name}' does not allow null values"); - } + if (QueryContext::$strictSQLMode) { + // if we got this far the column has no default and isn't nullable, strict would throw + // but default MySQL mode would coerce to a valid value + throw new SQLFakeRuntimeException("Column 
'{$field_name}' on '{$table_name}' does not allow null values"); + } - switch ($field_type) { - case 'int': - return 0; - break; - case 'double': - return 0.0; - break; - default: - return ''; - break; - } - } + switch ($field_type) { + case 'int': + return 0; + break; + case 'double': + return 0.0; + break; + default: + return ''; + break; + } + } - /** - * Ensure all fields from the table schema are present in the row - * Applies default values based on either DEFAULTs, nullable fields, or data types - */ - public static function ensureFieldsPresent(dict $row, table_schema $schema): dict { + /** + * Ensure all fields from the table schema are present in the row + * Applies default values based on either DEFAULTs, nullable fields, or data types + */ + public static function ensureFieldsPresent(dict $row, table_schema $schema): dict { - foreach ($schema['fields'] as $field) { - $field_name = $field['name']; - $field_type = $field['hack_type']; - $field_length = $field['length']; - $field_mysql_type = $field['type']; - $field_nullable = $field['null'] ?? false; - $field_default = $field['default'] ?? null; - $field_unsigned = $field['unsigned'] ?? false; + foreach ($schema['fields'] as $field) { + $field_name = $field['name']; + $field_type = $field['hack_type']; + $field_length = $field['length']; + $field_mysql_type = $field['type']; + $field_nullable = $field['null'] ?? false; + $field_default = $field['default'] ?? null; + $field_unsigned = $field['unsigned'] ?? 
false; - if (!C\contains_key($row, $field_name)) { - $row[$field_name] = - self::getDefaultValueForField($field_type, $field_nullable, $field_default, $field_name, $schema['name']); - } else if ($row[$field_name] === null) { - if ($field_nullable) { - // explicit null value and nulls are allowed, let it through - continue; - } else if (QueryContext::$strictSQLMode) { - // if we got this far the column has no default and isn't nullable, strict would throw - // but default MySQL mode would coerce to a valid value - throw new SQLFakeRuntimeException("Column '{$field_name}' on '{$schema['name']}' does not allow null values"); - } else { - $row[$field_name] = - self::getDefaultValueForField($field_type, $field_nullable, $field_default, $field_name, $schema['name']); - } - } else { - // TODO more integrity constraints, check field length for varchars, check timestamps - switch ($field_type) { - case 'int': - if ($row[$field_name] is bool) { - $row[$field_name] = (int)$row[$field_name]; - } else if (!$row[$field_name] is int) { - if (QueryContext::$strictSQLMode) { - $field_str = \var_export($row[$field_name], true); - throw new SQLFakeRuntimeException( - "Invalid value {$field_str} for column '{$field_name}' on '{$schema['name']}', expected int", - ); - } else { - $row[$field_name] = (int)$row[$field_name]; - } - } else { - $signed = !($field_unsigned); - $field_value = (int)$row[$field_name]; + if (!C\contains_key($row, $field_name)) { + $row[$field_name] = + self::getDefaultValueForField($field_type, $field_nullable, $field_default, $field_name, $schema['name']); + } else if ($row[$field_name] === null) { + if ($field_nullable) { + // explicit null value and nulls are allowed, let it through + continue; + } else if (QueryContext::$strictSQLMode) { + // if we got this far the column has no default and isn't nullable, strict would throw + // but default MySQL mode would coerce to a valid value + throw new SQLFakeRuntimeException("Column '{$field_name}' on 
'{$schema['name']}' does not allow null values"); + } else { + $row[$field_name] = + self::getDefaultValueForField($field_type, $field_nullable, $field_default, $field_name, $schema['name']); + } + } else { + // TODO more integrity constraints, check field length for varchars, check timestamps + switch ($field_type) { + case 'int': + if ($row[$field_name] is bool) { + $row[$field_name] = (int)$row[$field_name]; + } else if (!$row[$field_name] is int) { + if (QueryContext::$strictSQLMode) { + $field_str = \var_export($row[$field_name], true); + throw new SQLFakeRuntimeException( + "Invalid value {$field_str} for column '{$field_name}' on '{$schema['name']}', expected int", + ); + } else { + $row[$field_name] = (int)$row[$field_name]; + } + } else { + $signed = !($field_unsigned); + $field_value = (int)$row[$field_name]; - switch($field_mysql_type) { - case DataType::TINYINT: - if ($field_value < (($signed) ? -\pow(2,7) : 0) || $field_value >= (($signed) ? \pow(2,7) : \pow(2,8))){ - throw new SQLFakeRuntimeException( - "Column '{$field_name}' on '{$schema['name']}' expects a valid '{$field_mysql_type}'", - ); - } - break; - case DataType::SMALLINT: - if ($field_value < (($signed) ? -\pow(2,15) : 0) || $field_value >= (($signed) ? \pow(2,15) : \pow(2,16))){ - throw new SQLFakeRuntimeException( - "Column '{$field_name}' on '{$schema['name']}' expects a valid '{$field_mysql_type}'", - ); - } - break; - case DataType::MEDIUMINT: - if ($field_value < (($signed) ? -\pow(2,23) : 0) || $field_value >= (($signed) ? \pow(2,23) : \pow(2,24))){ - throw new SQLFakeRuntimeException( - "Column '{$field_name}' on '{$schema['name']}' expects a valid '{$field_mysql_type}'", - ); - } - break; - case DataType::INT: - if ($field_value < (($signed) ? -\pow(2,31) : 0) || $field_value >= (($signed) ? 
\pow(2,31) : \pow(2,32))){ - throw new SQLFakeRuntimeException( - "Column '{$field_name}' on '{$schema['name']}' expects a valid '{$field_mysql_type}'", - ); - } - break; - case DataType::BIGINT: - if ($field_value < (($signed) ? -\pow(2,63) : 0) || $field_value >= (($signed) ? \pow(2,63) : \pow(2,64))){ - throw new SQLFakeRuntimeException( - "Column '{$field_name}' on '{$schema['name']}' expects a valid '{$field_mysql_type}'", - ); - } - break; - default: - throw new SQLFakeRuntimeException( - "Column '{$field_name}' on '{$schema['name']}' expects a valid '{$field_mysql_type}'", - ); - break; - } - } - break; - case 'double': - if (!$row[$field_name] is float) { - if (QueryContext::$strictSQLMode) { - $field_str = \var_export($row[$field_name], true); - throw new SQLFakeRuntimeException( - "Invalid value '{$field_str}' for column '{$field_name}' on '{$schema['name']}', expected float", - ); - } else { - $row[$field_name] = (float)$row[$field_name]; - } - } - break; - default: - if (!$row[$field_name] is string) { - if (QueryContext::$strictSQLMode) { - $field_str = \var_export($row[$field_name], true); - throw new SQLFakeRuntimeException( - "Invalid value '{$field_str}' for column '{$field_name}' on '{$schema['name']}', expected string", - ); - } else { - $row[$field_name] = (string)$row[$field_name]; - } - } else { - $field_value = (string)$row[$field_name]; - // handle json column type validation - if ($field_mysql_type === DataType::JSON) { - // null is okay - if ($row[$field_name] is nonnull) { - if (!Str\is_empty((string)$row[$field_name])) { - // validate json string - $json_obj = \json_decode((string)$row[$field_name]); - if ($json_obj is null) { - // MySQL will accept the string 'null' in a json column and it converts it to a proper NULL - // the string 'null', however, returns NULL when decoded via \json_decode() which is the same - // as what we get from decoding invalid json - if ((string)$row[$field_name] === 'null') { - $row[$field_name] = null; - 
$field_value = null; - } else { - // invalid json - throw new SQLFakeRuntimeException( - "Invalid value '{$field_value}' for column '{$field_name}' on '{$schema['name']}', expected json", - ); - } - } - } else { - // empty strings are not valid for json columns - throw new SQLFakeRuntimeException( - "Invalid value '{$field_value}' for column '{$field_name}' on '{$schema['name']}', expected json", - ); - } - } - } else if ($field_length > 0 && \mb_strlen($field_value) > $field_length) { - $field_str = \var_export($row[$field_name], true); - throw new SQLFakeRuntimeException( - "Invalid value '{$field_str}' for column '{$field_name}' on '{$schema['name']}', expected string of size {$field_length}", - ); - } - } - break; - } - } - } + switch ($field_mysql_type) { + case DataType::TINYINT: + if ( + $field_value < (($signed) ? -\pow(2, 7) : 0) || + $field_value >= (($signed) ? \pow(2, 7) : \pow(2, 8)) + ) { + throw new SQLFakeRuntimeException( + "Column '{$field_name}' on '{$schema['name']}' expects a valid '{$field_mysql_type}'", + ); + } + break; + case DataType::SMALLINT: + if ( + $field_value < (($signed) ? -\pow(2, 15) : 0) || + $field_value >= (($signed) ? \pow(2, 15) : \pow(2, 16)) + ) { + throw new SQLFakeRuntimeException( + "Column '{$field_name}' on '{$schema['name']}' expects a valid '{$field_mysql_type}'", + ); + } + break; + case DataType::MEDIUMINT: + if ( + $field_value < (($signed) ? -\pow(2, 23) : 0) || + $field_value >= (($signed) ? \pow(2, 23) : \pow(2, 24)) + ) { + throw new SQLFakeRuntimeException( + "Column '{$field_name}' on '{$schema['name']}' expects a valid '{$field_mysql_type}'", + ); + } + break; + case DataType::INT: + if ( + $field_value < (($signed) ? -\pow(2, 31) : 0) || + $field_value >= (($signed) ? \pow(2, 31) : \pow(2, 32)) + ) { + throw new SQLFakeRuntimeException( + "Column '{$field_name}' on '{$schema['name']}' expects a valid '{$field_mysql_type}'", + ); + } + break; + case DataType::BIGINT: + if ( + $field_value < (($signed) ? 
-\pow(2, 63) : 0) || + $field_value >= (($signed) ? \pow(2, 63) : \pow(2, 64)) + ) { + throw new SQLFakeRuntimeException( + "Column '{$field_name}' on '{$schema['name']}' expects a valid '{$field_mysql_type}'", + ); + } + break; + default: + throw new SQLFakeRuntimeException( + "Column '{$field_name}' on '{$schema['name']}' expects a valid '{$field_mysql_type}'", + ); + break; + } + } + break; + case 'double': + if (!$row[$field_name] is float) { + if (QueryContext::$strictSQLMode) { + $field_str = \var_export($row[$field_name], true); + throw new SQLFakeRuntimeException( + "Invalid value '{$field_str}' for column '{$field_name}' on '{$schema['name']}', expected float", + ); + } else { + $row[$field_name] = (float)$row[$field_name]; + } + } + break; + default: + if (!$row[$field_name] is string) { + if (QueryContext::$strictSQLMode) { + $field_str = \var_export($row[$field_name], true); + throw new SQLFakeRuntimeException( + "Invalid value '{$field_str}' for column '{$field_name}' on '{$schema['name']}', expected string", + ); + } else { + $row[$field_name] = (string)$row[$field_name]; + } + } else { + $field_value = (string)$row[$field_name]; + // handle json column type validation + if ($field_mysql_type === DataType::JSON) { + // null is okay + if ($row[$field_name] is nonnull) { + if (!Str\is_empty((string)$row[$field_name])) { + // validate json string + $json_obj = \json_decode((string)$row[$field_name]); + if ($json_obj is null) { + // MySQL will accept the string 'null' in a json column and it converts it to a proper NULL + // the string 'null', however, returns NULL when decoded via \json_decode() which is the same + // as what we get from decoding invalid json + if ((string)$row[$field_name] === 'null') { + $row[$field_name] = null; + $field_value = null; + } else { + // invalid json + throw new SQLFakeRuntimeException( + "Invalid value '{$field_value}' for column '{$field_name}' on '{$schema['name']}', expected json", + ); + } + } + } else { + // empty 
strings are not valid for json columns + throw new SQLFakeRuntimeException( + "Invalid value '{$field_value}' for column '{$field_name}' on '{$schema['name']}', expected json", + ); + } + } + } else if ($field_length > 0 && \mb_strlen($field_value) > $field_length) { + $field_str = \var_export($row[$field_name], true); + throw new SQLFakeRuntimeException( + "Invalid value '{$field_str}' for column '{$field_name}' on '{$schema['name']}', expected string of size {$field_length}", + ); + } + } + break; + } + } + } - return $row; - } + return $row; + } - /** - * Ensure default values are present, coerce data types as MySQL would - */ - public static function coerceToSchema(dict $row, table_schema $schema): dict { + /** + * Ensure default values are present, coerce data types as MySQL would + */ + public static function coerceToSchema(dict $row, table_schema $schema): dict { - $fields = self::namesForSchema($schema); - $bad_fields = Keyset\keys($row) |> Keyset\diff($$, $fields); - if (!C\is_empty($bad_fields)) { - $bad_fields = Str\join($bad_fields, ', '); - throw new SQLFakeRuntimeException("Column(s) '{$bad_fields}' not found on '{$schema['name']}'"); - } + $fields = self::namesForSchema($schema); + $bad_fields = Keyset\keys($row) |> Keyset\diff($$, $fields); + if (!C\is_empty($bad_fields)) { + $bad_fields = Str\join($bad_fields, ', '); + throw new SQLFakeRuntimeException("Column(s) '{$bad_fields}' not found on '{$schema['name']}'"); + } - $row = self::ensureFieldsPresent($row, $schema); + $row = self::ensureFieldsPresent($row, $schema); - foreach ($schema['fields'] as $field) { - $field_name = $field['name']; - $field_type = $field['hack_type']; + foreach ($schema['fields'] as $field) { + $field_name = $field['name']; + $field_type = $field['hack_type']; - // don't coerce null values on nullable fields - if ($field['null'] && $row[$field_name] === null) { - continue; - } + // don't coerce null values on nullable fields + if ($field['null'] && $row[$field_name] === 
null) { + continue; + } - switch ($field_type) { - case 'int': - $row[$field_name] = (int)$row[$field_name]; - break; - case 'string': - $row[$field_name] = (string)$row[$field_name]; - break; - case 'double': - case 'float': - $row[$field_name] = (float)$row[$field_name]; - break; - default: - throw new SQLFakeRuntimeException( - "DataIntegrity::coerceToSchema found unknown type for field: '{$field_name}:{$field_type}'", - ); - } - } + switch ($field_type) { + case 'int': + $row[$field_name] = (int)$row[$field_name]; + break; + case 'string': + $row[$field_name] = (string)$row[$field_name]; + break; + case 'double': + case 'float': + $row[$field_name] = (float)$row[$field_name]; + break; + default: + throw new SQLFakeRuntimeException( + "DataIntegrity::coerceToSchema found unknown type for field: '{$field_name}:{$field_type}'", + ); + } + } - return $row; - } + return $row; + } - /** - * Check for unique key violations - * If there's a violation, this returns a string message, as well as the integer id of the row that conflicted - * Caller may decide to throw using the message, or make use of the row id to do an update - */ - public static function checkUniqueConstraints( - dataset $table, - dict $row, - table_schema $schema, - ?int $update_row_id = null, - ): ?(string, int) { + /** + * Check for unique key violations + * If there's a violation, this returns a string message, as well as the integer id of the row that conflicted + * Caller may decide to throw using the message, or make use of the row id to do an update + */ + public static function checkUniqueConstraints( + dataset $table, + dict $row, + table_schema $schema, + ?int $update_row_id = null, + ): ?(string, int) { - // gather all unique keys - $unique_keys = dict[]; - foreach ($schema['indexes'] as $index) { - if ($index['type'] === 'PRIMARY') { - $unique_keys['PRIMARY'] = keyset($index['fields']); - } else if ($index['type'] === 'UNIQUE') { - $unique_keys[$index['name']] = keyset($index['fields']); - 
} - } + // gather all unique keys + $unique_keys = dict[]; + foreach ($schema['indexes'] as $index) { + if ($index['type'] === 'PRIMARY') { + $unique_keys['PRIMARY'] = keyset($index['fields']); + } else if ($index['type'] === 'UNIQUE') { + $unique_keys[$index['name']] = keyset($index['fields']); + } + } - foreach ($unique_keys as $name => $unique_key) { - // unique key that allows nullable fields? if any of this key's fields on our candidate row are null, skip this key - // primary keys don't ever allow this - if ($name !== 'PRIMARY' && C\any($unique_key, $key ==> $row[$key] === null)) { - continue; - } + foreach ($unique_keys as $name => $unique_key) { + // unique key that allows nullable fields? if any of this key's fields on our candidate row are null, skip this key + // primary keys don't ever allow this + if ($name !== 'PRIMARY' && C\any($unique_key, $key ==> $row[$key] === null)) { + continue; + } - // are there any existing rows in the table for which every unique key field matches this row? - foreach ($table as $row_id => $r) { - // if we're updating and this is the row from the original table that we're updating, don't check that one - if ($row_id === $update_row_id) { - continue; - } - if (C\every($unique_key, $field ==> $r[$field] === $row[$field])) { - $dupe_unique_key_value = Vec\map($unique_key, $field ==> (string)$row[$field]) |> Str\join($$, ', '); - return - tuple("Duplicate entry '{$dupe_unique_key_value}' for key '{$name}' in table '{$schema['name']}'", $row_id); - } - } - } + // are there any existing rows in the table for which every unique key field matches this row? 
+ foreach ($table as $row_id => $r) { + // if we're updating and this is the row from the original table that we're updating, don't check that one + if ($row_id === $update_row_id) { + continue; + } + if (C\every($unique_key, $field ==> $r[$field] === $row[$field])) { + $dupe_unique_key_value = Vec\map($unique_key, $field ==> (string)$row[$field]) |> Str\join($$, ', '); + return + tuple("Duplicate entry '{$dupe_unique_key_value}' for key '{$name}' in table '{$schema['name']}'", $row_id); + } + } + } - return null; - } + return null; + } } diff --git a/src/Expressions/BaseFunctionExpression.hack b/src/Expressions/BaseFunctionExpression.hack index b35d3ef..1316072 100644 --- a/src/Expressions/BaseFunctionExpression.hack +++ b/src/Expressions/BaseFunctionExpression.hack @@ -7,47 +7,47 @@ use namespace HH\Lib\C; */ abstract class BaseFunctionExpression extends Expression { - protected string $functionName; - protected bool $evaluatesGroups = true; - - public function __construct(private token $token, protected vec $args, protected bool $distinct) { - $this->type = $token['type']; - $this->precedence = 0; - $this->functionName = $token['value']; - $this->name = $token['value']; - /*HH_FIXME[4110] Open issue #24 should resolve what to do here.*/ - $this->operator = (string)$this->type; - } - - public function functionName(): string { - return $this->functionName; - } - - <<__Override>> - public function isWellFormed(): bool { - return true; - } - - /** - * helper for functions which take one expression as an argument - */ - protected function getExpr(): Expression { - invariant(C\count($this->args) === 1, 'expression must have one argument'); - return C\firstx($this->args); - } - - <<__Override>> - public function __debugInfo(): dict { - $args = vec[]; - foreach ($this->args as $arg) { - $args[] = \var_dump($arg, true); - } - return dict[ - 'type' => (string)$this->type, - 'functionName' => $this->functionName, - 'args' => $args, - 'name' => $this->name, - 'distinct' => 
$this->distinct, - ]; - } + protected string $functionName; + protected bool $evaluatesGroups = true; + + public function __construct(private token $token, protected vec $args, protected bool $distinct) { + $this->type = $token['type']; + $this->precedence = 0; + $this->functionName = $token['value']; + $this->name = $token['value']; + /*HH_FIXME[4110] Open issue #24 should resolve what to do here.*/ + $this->operator = (string)$this->type; + } + + public function functionName(): string { + return $this->functionName; + } + + <<__Override>> + public function isWellFormed(): bool { + return true; + } + + /** + * helper for functions which take one expression as an argument + */ + protected function getExpr(): Expression { + invariant(C\count($this->args) === 1, 'expression must have one argument'); + return C\firstx($this->args); + } + + <<__Override>> + public function __debugInfo(): dict { + $args = vec[]; + foreach ($this->args as $arg) { + $args[] = \var_dump($arg, true); + } + return dict[ + 'type' => (string)$this->type, + 'functionName' => $this->functionName, + 'args' => $args, + 'name' => $this->name, + 'distinct' => $this->distinct, + ]; + } } diff --git a/src/Expressions/BetweenOperatorExpression.php b/src/Expressions/BetweenOperatorExpression.php index f003c68..0c8a0e3 100644 --- a/src/Expressions/BetweenOperatorExpression.php +++ b/src/Expressions/BetweenOperatorExpression.php @@ -7,136 +7,136 @@ */ final class BetweenOperatorExpression extends Expression { - private ?Expression $start = null; - private ?Expression $end = null; - private bool $and = false; - protected bool $evaluates_groups = false; - public bool $negated = false; - - public function __construct(private Expression $left) { - $op = Operator::BETWEEN; - $this->name = ''; - $this->precedence = ExpressionParser::OPERATOR_PRECEDENCE[operator_to_string($op)]; - $this->operator = $op; - $this->type = TokenType::OPERATOR; - } - - <<__Override>> - public function evaluateImpl(row $row, 
AsyncMysqlConnection $conn): bool { - $start = $this->start; - $end = $this->end; - if ($start === null || $end === null) { - throw new SQLFakeRuntimeException('Attempted to evaluate incomplete BETWEEN expression'); - } - - // any part of the between clause could be a column or a literal, so check each one - $subject = $this->left->evaluate($row, $conn); - $start = $start->evaluate($row, $conn); - $end = $end->evaluate($row, $conn); - - // between clause is lower and upper inclusive - if ($subject is num) { - $subject = (int)$subject; - $start = (int)$start; - $end = (int)$end; - $eval = $subject >= $start && $subject <= $end; - } else { - $subject = (string)$subject; - $start = (string)$start; - $end = (string)$end; - $eval = $subject >= $start && $subject <= $end; - } - - return ($this->negated ? !$eval : $eval) ? true : false; - } - - <<__Override>> - public function negate(): void { - $this->negated = true; - } - - <<__Override>> - public function isWellFormed(): bool { - return $this->start && $this->end; - } - - public function setStart(Expression $expr): void { - $this->start = $expr; - } - - public function setEnd(Expression $expr): void { - $this->end = $expr; - } - - public function foundAnd(): void { - if ($this->and || !$this->start) { - throw new SQLFakeParseException('Unexpected AND'); - } - $this->and = true; - } - - <<__Override>> - public function setNextChild(Expression $expr, bool $overwrite = false): void { - if ($overwrite) { - // this mode is when we come out of a recursive expression where we had to pull out the most recent token, so overwrite that expression with the result - if ($this->end) { - $this->end = $expr; - } else if ($this->start) { - $this->start = $expr; - } else { - $this->left = $expr; - } - return; - } - - if (!$this->start) { - $this->start = $expr; - } else if ($this->and && !$this->end) { - $this->end = $expr; - } else { - throw new SQLFakeParseException('Parse error: unexpected token in BETWEEN statement'); - } - } - - 
private function getLatestExpression(): Expression { - if ($this->end) { - return $this->end; - } - if ($this->start) { - return $this->start; - } - return $this->left; - } - - <<__Override>> - public function addRecursiveExpression(token_list $tokens, int $pointer, bool $negated = false): int { - - $tmp = new BinaryOperatorExpression($this->getLatestExpression()); - - $p = new ExpressionParser($tokens, $pointer, $tmp, $this->precedence, /* $is_child */ true); - list($pointer, $new_expression) = $p->buildWithPointer(); - - if ($negated) { - $new_expression->negate(); - } - - $this->setNextChild($new_expression, true); - - return $pointer; - } - - <<__Override>> - public function __debugInfo(): dict { - $ret = dict[ - 'type' => (string)$this->type, - 'left' => \var_dump($this->left, true), - 'start' => $this->start ? \var_dump($this->start, true) : dict[], - 'end' => $this->end ? \var_dump($this->end, true) : dict[], - ]; - - if ((bool)$this->name) { - $ret['name'] = $this->name; - } - return $ret; - } + private ?Expression $start = null; + private ?Expression $end = null; + private bool $and = false; + protected bool $evaluates_groups = false; + public bool $negated = false; + + public function __construct(private Expression $left) { + $op = Operator::BETWEEN; + $this->name = ''; + $this->precedence = ExpressionParser::OPERATOR_PRECEDENCE[operator_to_string($op)]; + $this->operator = $op; + $this->type = TokenType::OPERATOR; + } + + <<__Override>> + public function evaluateImpl(row $row, AsyncMysqlConnection $conn): bool { + $start = $this->start; + $end = $this->end; + if ($start === null || $end === null) { + throw new SQLFakeRuntimeException('Attempted to evaluate incomplete BETWEEN expression'); + } + + // any part of the between clause could be a column or a literal, so check each one + $subject = $this->left->evaluate($row, $conn); + $start = $start->evaluate($row, $conn); + $end = $end->evaluate($row, $conn); + + // between clause is lower and upper 
inclusive + if ($subject is num) { + $subject = (int)$subject; + $start = (int)$start; + $end = (int)$end; + $eval = $subject >= $start && $subject <= $end; + } else { + $subject = (string)$subject; + $start = (string)$start; + $end = (string)$end; + $eval = $subject >= $start && $subject <= $end; + } + + return ($this->negated ? !$eval : $eval) ? true : false; + } + + <<__Override>> + public function negate(): void { + $this->negated = true; + } + + <<__Override>> + public function isWellFormed(): bool { + return $this->start && $this->end; + } + + public function setStart(Expression $expr): void { + $this->start = $expr; + } + + public function setEnd(Expression $expr): void { + $this->end = $expr; + } + + public function foundAnd(): void { + if ($this->and || !$this->start) { + throw new SQLFakeParseException('Unexpected AND'); + } + $this->and = true; + } + + <<__Override>> + public function setNextChild(Expression $expr, bool $overwrite = false): void { + if ($overwrite) { + // this mode is when we come out of a recursive expression where we had to pull out the most recent token, so overwrite that expression with the result + if ($this->end) { + $this->end = $expr; + } else if ($this->start) { + $this->start = $expr; + } else { + $this->left = $expr; + } + return; + } + + if (!$this->start) { + $this->start = $expr; + } else if ($this->and && !$this->end) { + $this->end = $expr; + } else { + throw new SQLFakeParseException('Parse error: unexpected token in BETWEEN statement'); + } + } + + private function getLatestExpression(): Expression { + if ($this->end) { + return $this->end; + } + if ($this->start) { + return $this->start; + } + return $this->left; + } + + <<__Override>> + public function addRecursiveExpression(token_list $tokens, int $pointer, bool $negated = false): int { + + $tmp = new BinaryOperatorExpression($this->getLatestExpression()); + + $p = new ExpressionParser($tokens, $pointer, $tmp, $this->precedence, /* $is_child */ true); + 
list($pointer, $new_expression) = $p->buildWithPointer(); + + if ($negated) { + $new_expression->negate(); + } + + $this->setNextChild($new_expression, true); + + return $pointer; + } + + <<__Override>> + public function __debugInfo(): dict { + $ret = dict[ + 'type' => (string)$this->type, + 'left' => \var_dump($this->left, true), + 'start' => $this->start ? \var_dump($this->start, true) : dict[], + 'end' => $this->end ? \var_dump($this->end, true) : dict[], + ]; + + if ((bool)$this->name) { + $ret['name'] = $this->name; + } + return $ret; + } } diff --git a/src/Expressions/BinaryOperatorExpression.php b/src/Expressions/BinaryOperatorExpression.php index d32a76e..beb49cb 100644 --- a/src/Expressions/BinaryOperatorExpression.php +++ b/src/Expressions/BinaryOperatorExpression.php @@ -11,402 +11,402 @@ */ final class BinaryOperatorExpression extends Expression { - protected bool $evaluates_groups = false; - protected int $negatedInt = 0; - - public function __construct( - public Expression $left, // public because we sometimes need to access it to split off into a BETWEEN - public bool $negated = false, - public ?Operator $operator = null, - public ?Expression $right = null, - ) { - $this->name = ''; - // this gets overwritten once we have an operator - $this->precedence = 0; - $this->type = TokenType::OPERATOR; - if ($operator is nonnull) { - $this->precedence = ExpressionParser::OPERATOR_PRECEDENCE[operator_to_string($operator)]; - } - - $this->negatedInt = $this->negated ? 
1 : 0; - } - - /** - * Runs the comparison on each element between the left and right, - * BUT if the values are equal it keeps checking down the list - * (1, 2, 3) > (1, 2, 2) for example - * (1, 2, 3) > (1, 1, 4) is also true - */ - private function evaluateRowComparison( - RowExpression $left, - RowExpression $right, - row $row, - AsyncMysqlConnection $conn, - ): bool { - - $left_elems = $left->evaluate($row, $conn); - invariant($left_elems is vec<_>, 'RowExpression must return vec'); - - $right_elems = $right->evaluate($row, $conn); - invariant($right_elems is vec<_>, 'RowExpression must return vec'); - - if (C\count($left_elems) !== C\count($right_elems)) { - throw new SQLFakeRuntimeException('Mismatched column count in row comparison expression'); - } - - $last_index = C\last_key($left_elems); - - foreach ($left_elems as $index => $le) { - $re = $right_elems[$index]; - - // in an expression like (1, 2, 3) > (1, 2, 2) we don't need EVERY element on the left to be greater than the right - // some can be equal. 
so if we get to one that isn't the last and they're equal, it's safe to keep going - if (\HH\Lib\Legacy_FIXME\eq($le, $re) && $index !== $last_index) { - continue; - } - - // as soon as you find any pair of elements that aren't equal, you can return whatever their comparison result is immediately - // this is why (1, 2, 3) > (1, 1, 4) is true, for example, because the 2nd element comparison returns immediately - switch ($this->operator as Operator) { - case Operator::EQUALS: - return ($le == $re); - case Operator::LESS_THAN_EQUALS_GREATER_THAN: - case Operator::BANG_EQUALS: - return ($le != $re); - case Operator::GREATER_THAN: - /* HH_IGNORE_ERROR[4240] assume they have the same types */ - return ($le > $re); - case Operator::GREATER_THAN_EQUALS: - /* HH_IGNORE_ERROR[4240] assume they have the same types */ - return ($le >= $re); - case Operator::LESS_THAN: - /* HH_IGNORE_ERROR[4240] assume they have the same types */ - return \HH\Lib\Legacy_FIXME\lt($le, $re); - case Operator::LESS_THAN_EQUALS: - /* HH_IGNORE_ERROR[4240] assume they have the same types */ - return ($le <= $re); - default: - throw new SQLFakeRuntimeException("Operand {$this->operator} should contain 1 column(s)"); - } - } - - return false; - } - - <<__Override>> - public function evaluateImpl(row $row, AsyncMysqlConnection $conn): mixed { - $right = $this->right; - $left = $this->left; - - if ($left is RowExpression) { - if (!$right is RowExpression) { - throw new SQLFakeRuntimeException('Expected row expression on RHS of '.(string)$this->operator.' operand'); - } - - // oh fun! a row comparison, e.g. 
(col1, col2, col3) > (1, 2, 3) - // these are handled somewhat differently from all other binary operands since you need to loop and compare each element - // also we cast to int because that's how MySQL would return these - return $this->evaluateRowComparison($left, $right, $row, $conn); - } - - if ($right === null) { - throw new SQLFakeRuntimeException('Attempted to evaluate BinaryOperatorExpression with no right operand'); - } - - $as_string = $left->getType() == TokenType::STRING_CONSTANT || $right->getType() == TokenType::STRING_CONSTANT; - - $op = $this->operator; - if ($op === null) { - // an operator should only be in this state in the middle of parsing, never when evaluating - throw new SQLFakeRuntimeException('Attempted to evaluate BinaryOperatorExpression with empty operator'); - } - - // special handling for AND/OR - when possible, return without evaluating $right - if ($op === Operator::AND) { - $l_value = $left->evaluate($row, $conn); - if (!$l_value) { - return $this->negated; - } - $r_value = $right->evaluate($row, $conn); - if (!$r_value) { - return $this->negated; - } - return !$this->negated; - } else if ($op === Operator::OR) { - $l_value = $left->evaluate($row, $conn); - if ($l_value) { - return !$this->negated; - } - $r_value = $right->evaluate($row, $conn); - if ($r_value) { - return !$this->negated; - } - return $this->negated; - } - - $l_value = $left->evaluate($row, $conn); - $r_value = $right->evaluate($row, $conn); - - switch ($op) { - case Operator::AND: - case Operator::OR: - invariant(false, 'impossible to arrive here'); - case Operator::EQUALS: - // maybe do some stuff with data types here - // comparing strings: gotta think about collation and case sensitivity! - return (bool)(\HH\Lib\Legacy_FIXME\eq($l_value, $r_value) ? 1 : 0 ^ $this->negatedInt); - case Operator::LESS_THAN_GREATER_THAN: - case Operator::BANG_EQUALS: - if ($as_string) { - return (bool)(((string)$l_value != (string)$r_value) ? 
1 : 0 ^ $this->negatedInt); - } else { - return (bool)(((float)$l_value != (float)$r_value) ? 1 : 0 ^ $this->negatedInt); - } - case Operator::GREATER_THAN: - if ($as_string) { - return (bool)((((Str\compare((string)$l_value, (string)$r_value)) > 0) ? 1 : 0) ^ $this->negatedInt); - } else { - return (bool)(((float)$l_value > (float)$r_value) ? 1 : 0 ^ $this->negatedInt); - } - case Operator::GREATER_THAN_EQUALS: - if ($as_string) { - $comparison = Str\compare((string)$l_value, (string)$r_value); - return (bool)((($comparison > 0 || $comparison === 0) ? 1 : 0) ^ $this->negatedInt); - } else { - return (bool)(((float)$l_value >= (float)$r_value) ? 1 : 0 ^ $this->negatedInt); - } - case Operator::LESS_THAN: - if ($as_string) { - return (bool)((((Str\compare((string)$l_value, (string)$r_value)) < 0) ? 1 : 0) ^ $this->negatedInt); - } else { - return (bool)(((float)$l_value < (float)$r_value) ? 1 : 0 ^ $this->negatedInt); - } - case Operator::LESS_THAN_EQUALS: - if ($as_string) { - $comparison = Str\compare((string)$l_value, (string)$r_value); - return (bool)((($comparison < 0 || $comparison === 0) ? 1 : 0) ^ $this->negatedInt); - } else { - return (bool)(((float)$l_value <= (float)$r_value) ? 
1 : 0 ^ $this->negatedInt); - } - case Operator::ASTERISK: - case Operator::PERCENT: - case Operator::MOD: - case Operator::MINUS: - case Operator::PLUS: - case Operator::DOUBLE_LESS_THAN: - case Operator::DOUBLE_GREATER_THAN: - case Operator::FORWARD_SLASH: - case Operator::DIV: - // do these things to all numeric operators and then switch again to execute the actual operation - $left_number = $this->extractNumericValue($l_value); - $right_number = $this->extractNumericValue($r_value); - - switch ($op) { - case Operator::ASTERISK: - return $left_number * $right_number; - case Operator::PERCENT: - case Operator::MOD: - // mod is float-aware, not ints only like Hack's % operator - return \fmod((float)$left_number, (float)$right_number); - case Operator::FORWARD_SLASH: - return $left_number / $right_number; - case Operator::DIV: - // integer division - return (int)($left_number / $right_number); - case Operator::MINUS: - return $left_number - $right_number; - case Operator::PLUS: - return $left_number + $right_number; - case Operator::DOUBLE_LESS_THAN: - return (int)$left_number << (int)$right_number; - case Operator::DOUBLE_GREATER_THAN: - return (int)$left_number >> (int)$right_number; - default: - throw new SQLFakeRuntimeException('Operator '.(string)$this->operator.' 
recognized but not implemented'); - } - case Operator::LIKE: - $left_string = (string)$left->evaluate($row, $conn); - if (!$right is ConstantExpression) { - throw new SQLFakeRuntimeException('LIKE pattern should be a constant string'); - } - $pattern = (string)$r_value; - - $start_pattern = '^'; - $end_pattern = '$'; - - if ($pattern[0] == '%') { - $start_pattern = ''; - $pattern = Str\strip_prefix($pattern, '%'); - } - - if (Str\ends_with($pattern, '%')) { - $end_pattern = ''; - $pattern = Str\strip_suffix($pattern, '%'); - } - - // escape all + characters - $pattern = \preg_quote($pattern, '/'); - - // replace only unescaped % and _ characters to make regex - $pattern = Regex\replace($pattern, re"/(?negatedInt); - case Operator::IS: - if (!$right is ConstantExpression) { - throw new SQLFakeRuntimeException('Unsupported right operand for IS keyword'); - } - $val = $left->evaluate($row, $conn); - - $r = $r_value; - - if ($r === null) { - return (bool)(($val === null ? 1 : 0) ^ $this->negatedInt); - } - - // you can also do IS TRUE, IS FALSE, or IS UNKNOWN but I haven't implemented that yet mostly because those come through the parser as "RESERVED" rather than const expressions - - throw new SQLFakeRuntimeException('Unsupported right operand for IS keyword'); - case Operator::RLIKE: - case Operator::REGEXP: - $left_string = (string)$left->evaluate($row, $conn); - // if the regexp is wrapped in a BINARY function we will make it case sensitive - $case_insensitive = 'i'; - if ($right is FunctionExpression && $right->functionName() == 'BINARY') { - $case_insensitive = ''; - } - - $pattern = (string)$r_value; - $regex = '/'.$pattern.'/'.$case_insensitive; - - // xor here, so if negated is true and regex matches then we return false etc. - return (bool)(((bool)\preg_match($regex, $left_string) ? 
1 : 0) ^ $this->negatedInt); - case Operator::AMPERSAND: - return (int)$l_value & (int)$r_value; - case Operator::DOUBLE_AMPERSAND: - case Operator::BINARY: - case Operator::COLLATE: - case Operator::PIPE: - case Operator::CARET: - case Operator::LESS_THAN_EQUALS_GREATER_THAN: - case Operator::DOUBLE_PIPE: - case Operator::XOR: - case Operator::SOUNDS: - case Operator::ANY: // parser does NOT KNOW about this functionality - case Operator::SOME: // parser does NOT KNOW about this functionality - //[[fallthrough]] <- note to humans, not to the typechecker, therefore different syntax. - default: - throw new SQLFakeRuntimeException('Operator '.(string)$this->operator.' not implemented in SQLFake'); - } - } - - /** - * Coerce a mixed value to a num, - * but also handle sub-expressions that return a dataset containing a num - * such as "SELECT (SELECT COUNT(*) FROM ...) + 3 as thecountplusthree" - */ - protected function extractNumericValue(mixed $val): num { - if ($val is Container<_>) { - if (C\is_empty($val)) { - $val = 0; - } else { - // extract first row, then first column - $val = (C\firstx($val) as Container<_>) |> C\firstx($$); - } - } - return Str\contains((string)$val, '.') ? (float)$val : (int)$val; - } - - <<__Override>> - public function negate(): void { - $this->negated = true; - $this->negatedInt = 1; - } - - <<__Override>> - public function __debugInfo(): dict { - - $ret = dict[ - 'type' => $this->operator, - 'left' => \var_dump($this->left, true), - 'right' => $this->right ? 
\var_dump($this->right, true) : dict[], - ]; - - if (!Str\is_empty($this->name)) { - $ret['name'] = $this->name; - } - if ($this->negated) { - $ret['negated'] = 1; - } - return $ret; - } - - <<__Override>> - public function isWellFormed(): bool { - return $this->right && $this->operator is nonnull; - } - - <<__Override>> - public function setNextChild(Expression $expr, bool $overwrite = false): void { - if ($this->operator === null || ($this->right && !$overwrite)) { - throw new SQLFakeParseException('Parse error'); - } - $this->right = $expr; - } - - public function setOperator(Operator $operator): void { - $this->operator = $operator; - $this->precedence = ExpressionParser::OPERATOR_PRECEDENCE[operator_to_string($operator)]; - } - - public function getRightOrThrow(): Expression { - if ($this->right === null) { - throw new SQLFakeParseException('Parse error: attempted to resolve unbound expression'); - } - return $this->right; - } - - public function traverse(): vec { - $container = vec[]; - - if ($this->left is nonnull) { - if ($this->left is BinaryOperatorExpression) { - $container = Vec\concat($container, $this->left->traverse()); - } else { - $container[] = $this->left; - } - } - - if ($this->right is nonnull) { - if ($this->right is BinaryOperatorExpression) { - $container = Vec\concat($container, $this->right->traverse()); - } else { - $container[] = $this->right; - } - } - - return $container; - } - - <<__Override>> - public function addRecursiveExpression(token_list $tokens, int $pointer, bool $negated = false): int { - // this might not end up as a binary expression, but it is ok for it to start that way! - // $right could be empty here if we encountered an expression on the right hand side of an operator like, "column_name = CASE.." - $tmp = $this->right ? 
new BinaryOperatorExpression($this->right) : new PlaceholderExpression(); - - // what we want to do is tell the child to process itself until it finds a precedence lower than the parent - $p = new ExpressionParser($tokens, $pointer, $tmp, $this->precedence, true); - list($pointer, $new_expression) = $p->buildWithPointer(); - - if ($negated) { - $new_expression->negate(); - } - - $this->setNextChild($new_expression, true); - - return $pointer; - } + protected bool $evaluates_groups = false; + protected int $negatedInt = 0; + + public function __construct( + public Expression $left, // public because we sometimes need to access it to split off into a BETWEEN + public bool $negated = false, + public ?Operator $operator = null, + public ?Expression $right = null, + ) { + $this->name = ''; + // this gets overwritten once we have an operator + $this->precedence = 0; + $this->type = TokenType::OPERATOR; + if ($operator is nonnull) { + $this->precedence = ExpressionParser::OPERATOR_PRECEDENCE[operator_to_string($operator)]; + } + + $this->negatedInt = $this->negated ? 
1 : 0; + } + + /** + * Runs the comparison on each element between the left and right, + * BUT if the values are equal it keeps checking down the list + * (1, 2, 3) > (1, 2, 2) for example + * (1, 2, 3) > (1, 1, 4) is also true + */ + private function evaluateRowComparison( + RowExpression $left, + RowExpression $right, + row $row, + AsyncMysqlConnection $conn, + ): bool { + + $left_elems = $left->evaluate($row, $conn); + invariant($left_elems is vec<_>, 'RowExpression must return vec'); + + $right_elems = $right->evaluate($row, $conn); + invariant($right_elems is vec<_>, 'RowExpression must return vec'); + + if (C\count($left_elems) !== C\count($right_elems)) { + throw new SQLFakeRuntimeException('Mismatched column count in row comparison expression'); + } + + $last_index = C\last_key($left_elems); + + foreach ($left_elems as $index => $le) { + $re = $right_elems[$index]; + + // in an expression like (1, 2, 3) > (1, 2, 2) we don't need EVERY element on the left to be greater than the right + // some can be equal. 
so if we get to one that isn't the last and they're equal, it's safe to keep going + if (\HH\Lib\Legacy_FIXME\eq($le, $re) && $index !== $last_index) { + continue; + } + + // as soon as you find any pair of elements that aren't equal, you can return whatever their comparison result is immediately + // this is why (1, 2, 3) > (1, 1, 4) is true, for example, because the 2nd element comparison returns immediately + switch ($this->operator as Operator) { + case Operator::EQUALS: + return ($le == $re); + case Operator::LESS_THAN_EQUALS_GREATER_THAN: + case Operator::BANG_EQUALS: + return ($le != $re); + case Operator::GREATER_THAN: + /* HH_IGNORE_ERROR[4240] assume they have the same types */ + return ($le > $re); + case Operator::GREATER_THAN_EQUALS: + /* HH_IGNORE_ERROR[4240] assume they have the same types */ + return ($le >= $re); + case Operator::LESS_THAN: + /* HH_IGNORE_ERROR[4240] assume they have the same types */ + return \HH\Lib\Legacy_FIXME\lt($le, $re); + case Operator::LESS_THAN_EQUALS: + /* HH_IGNORE_ERROR[4240] assume they have the same types */ + return ($le <= $re); + default: + throw new SQLFakeRuntimeException("Operand {$this->operator} should contain 1 column(s)"); + } + } + + return false; + } + + <<__Override>> + public function evaluateImpl(row $row, AsyncMysqlConnection $conn): mixed { + $right = $this->right; + $left = $this->left; + + if ($left is RowExpression) { + if (!$right is RowExpression) { + throw new SQLFakeRuntimeException('Expected row expression on RHS of '.(string)$this->operator.' operand'); + } + + // oh fun! a row comparison, e.g. 
(col1, col2, col3) > (1, 2, 3) + // these are handled somewhat differently from all other binary operands since you need to loop and compare each element + // also we cast to int because that's how MySQL would return these + return $this->evaluateRowComparison($left, $right, $row, $conn); + } + + if ($right === null) { + throw new SQLFakeRuntimeException('Attempted to evaluate BinaryOperatorExpression with no right operand'); + } + + $as_string = $left->getType() == TokenType::STRING_CONSTANT || $right->getType() == TokenType::STRING_CONSTANT; + + $op = $this->operator; + if ($op === null) { + // an operator should only be in this state in the middle of parsing, never when evaluating + throw new SQLFakeRuntimeException('Attempted to evaluate BinaryOperatorExpression with empty operator'); + } + + // special handling for AND/OR - when possible, return without evaluating $right + if ($op === Operator::AND) { + $l_value = $left->evaluate($row, $conn); + if (!$l_value) { + return $this->negated; + } + $r_value = $right->evaluate($row, $conn); + if (!$r_value) { + return $this->negated; + } + return !$this->negated; + } else if ($op === Operator::OR) { + $l_value = $left->evaluate($row, $conn); + if ($l_value) { + return !$this->negated; + } + $r_value = $right->evaluate($row, $conn); + if ($r_value) { + return !$this->negated; + } + return $this->negated; + } + + $l_value = $left->evaluate($row, $conn); + $r_value = $right->evaluate($row, $conn); + + switch ($op) { + case Operator::AND: + case Operator::OR: + invariant(false, 'impossible to arrive here'); + case Operator::EQUALS: + // maybe do some stuff with data types here + // comparing strings: gotta think about collation and case sensitivity! + return (bool)(\HH\Lib\Legacy_FIXME\eq($l_value, $r_value) ? 1 : 0 ^ $this->negatedInt); + case Operator::LESS_THAN_GREATER_THAN: + case Operator::BANG_EQUALS: + if ($as_string) { + return (bool)(((string)$l_value != (string)$r_value) ? 
1 : 0 ^ $this->negatedInt); + } else { + return (bool)(((float)$l_value != (float)$r_value) ? 1 : 0 ^ $this->negatedInt); + } + case Operator::GREATER_THAN: + if ($as_string) { + return (bool)((((Str\compare((string)$l_value, (string)$r_value)) > 0) ? 1 : 0) ^ $this->negatedInt); + } else { + return (bool)(((float)$l_value > (float)$r_value) ? 1 : 0 ^ $this->negatedInt); + } + case Operator::GREATER_THAN_EQUALS: + if ($as_string) { + $comparison = Str\compare((string)$l_value, (string)$r_value); + return (bool)((($comparison > 0 || $comparison === 0) ? 1 : 0) ^ $this->negatedInt); + } else { + return (bool)(((float)$l_value >= (float)$r_value) ? 1 : 0 ^ $this->negatedInt); + } + case Operator::LESS_THAN: + if ($as_string) { + return (bool)((((Str\compare((string)$l_value, (string)$r_value)) < 0) ? 1 : 0) ^ $this->negatedInt); + } else { + return (bool)(((float)$l_value < (float)$r_value) ? 1 : 0 ^ $this->negatedInt); + } + case Operator::LESS_THAN_EQUALS: + if ($as_string) { + $comparison = Str\compare((string)$l_value, (string)$r_value); + return (bool)((($comparison < 0 || $comparison === 0) ? 1 : 0) ^ $this->negatedInt); + } else { + return (bool)(((float)$l_value <= (float)$r_value) ? 
1 : 0 ^ $this->negatedInt); + } + case Operator::ASTERISK: + case Operator::PERCENT: + case Operator::MOD: + case Operator::MINUS: + case Operator::PLUS: + case Operator::DOUBLE_LESS_THAN: + case Operator::DOUBLE_GREATER_THAN: + case Operator::FORWARD_SLASH: + case Operator::DIV: + // do these things to all numeric operators and then switch again to execute the actual operation + $left_number = $this->extractNumericValue($l_value); + $right_number = $this->extractNumericValue($r_value); + + switch ($op) { + case Operator::ASTERISK: + return $left_number * $right_number; + case Operator::PERCENT: + case Operator::MOD: + // mod is float-aware, not ints only like Hack's % operator + return \fmod((float)$left_number, (float)$right_number); + case Operator::FORWARD_SLASH: + return $left_number / $right_number; + case Operator::DIV: + // integer division + return (int)($left_number / $right_number); + case Operator::MINUS: + return $left_number - $right_number; + case Operator::PLUS: + return $left_number + $right_number; + case Operator::DOUBLE_LESS_THAN: + return (int)$left_number << (int)$right_number; + case Operator::DOUBLE_GREATER_THAN: + return (int)$left_number >> (int)$right_number; + default: + throw new SQLFakeRuntimeException('Operator '.(string)$this->operator.' 
recognized but not implemented'); + } + case Operator::LIKE: + $left_string = (string)$left->evaluate($row, $conn); + if (!$right is ConstantExpression) { + throw new SQLFakeRuntimeException('LIKE pattern should be a constant string'); + } + $pattern = (string)$r_value; + + $start_pattern = '^'; + $end_pattern = '$'; + + if ($pattern[0] == '%') { + $start_pattern = ''; + $pattern = Str\strip_prefix($pattern, '%'); + } + + if (Str\ends_with($pattern, '%')) { + $end_pattern = ''; + $pattern = Str\strip_suffix($pattern, '%'); + } + + // escape all + characters + $pattern = \preg_quote($pattern, '/'); + + // replace only unescaped % and _ characters to make regex + $pattern = Regex\replace($pattern, re"/(?negatedInt); + case Operator::IS: + if (!$right is ConstantExpression) { + throw new SQLFakeRuntimeException('Unsupported right operand for IS keyword'); + } + $val = $left->evaluate($row, $conn); + + $r = $r_value; + + if ($r === null) { + return (bool)(($val === null ? 1 : 0) ^ $this->negatedInt); + } + + // you can also do IS TRUE, IS FALSE, or IS UNKNOWN but I haven't implemented that yet mostly because those come through the parser as "RESERVED" rather than const expressions + + throw new SQLFakeRuntimeException('Unsupported right operand for IS keyword'); + case Operator::RLIKE: + case Operator::REGEXP: + $left_string = (string)$left->evaluate($row, $conn); + // if the regexp is wrapped in a BINARY function we will make it case sensitive + $case_insensitive = 'i'; + if ($right is FunctionExpression && $right->functionName() == 'BINARY') { + $case_insensitive = ''; + } + + $pattern = (string)$r_value; + $regex = '/'.$pattern.'/'.$case_insensitive; + + // xor here, so if negated is true and regex matches then we return false etc. + return (bool)(((bool)\preg_match($regex, $left_string) ? 
1 : 0) ^ $this->negatedInt); + case Operator::AMPERSAND: + return (int)$l_value & (int)$r_value; + case Operator::DOUBLE_AMPERSAND: + case Operator::BINARY: + case Operator::COLLATE: + case Operator::PIPE: + case Operator::CARET: + case Operator::LESS_THAN_EQUALS_GREATER_THAN: + case Operator::DOUBLE_PIPE: + case Operator::XOR: + case Operator::SOUNDS: + case Operator::ANY: // parser does NOT KNOW about this functionality + case Operator::SOME: // parser does NOT KNOW about this functionality + //[[fallthrough]] <- note to humans, not to the typechecker, therefore different syntax. + default: + throw new SQLFakeRuntimeException('Operator '.(string)$this->operator.' not implemented in SQLFake'); + } + } + + /** + * Coerce a mixed value to a num, + * but also handle sub-expressions that return a dataset containing a num + * such as "SELECT (SELECT COUNT(*) FROM ...) + 3 as thecountplusthree" + */ + protected function extractNumericValue(mixed $val): num { + if ($val is Container<_>) { + if (C\is_empty($val)) { + $val = 0; + } else { + // extract first row, then first column + $val = (C\firstx($val) as Container<_>) |> C\firstx($$); + } + } + return Str\contains((string)$val, '.') ? (float)$val : (int)$val; + } + + <<__Override>> + public function negate(): void { + $this->negated = true; + $this->negatedInt = 1; + } + + <<__Override>> + public function __debugInfo(): dict { + + $ret = dict[ + 'type' => $this->operator, + 'left' => \var_dump($this->left, true), + 'right' => $this->right ? 
\var_dump($this->right, true) : dict[], + ]; + + if (!Str\is_empty($this->name)) { + $ret['name'] = $this->name; + } + if ($this->negated) { + $ret['negated'] = 1; + } + return $ret; + } + + <<__Override>> + public function isWellFormed(): bool { + return $this->right && $this->operator is nonnull; + } + + <<__Override>> + public function setNextChild(Expression $expr, bool $overwrite = false): void { + if ($this->operator === null || ($this->right && !$overwrite)) { + throw new SQLFakeParseException('Parse error'); + } + $this->right = $expr; + } + + public function setOperator(Operator $operator): void { + $this->operator = $operator; + $this->precedence = ExpressionParser::OPERATOR_PRECEDENCE[operator_to_string($operator)]; + } + + public function getRightOrThrow(): Expression { + if ($this->right === null) { + throw new SQLFakeParseException('Parse error: attempted to resolve unbound expression'); + } + return $this->right; + } + + public function traverse(): vec { + $container = vec[]; + + if ($this->left is nonnull) { + if ($this->left is BinaryOperatorExpression) { + $container = Vec\concat($container, $this->left->traverse()); + } else { + $container[] = $this->left; + } + } + + if ($this->right is nonnull) { + if ($this->right is BinaryOperatorExpression) { + $container = Vec\concat($container, $this->right->traverse()); + } else { + $container[] = $this->right; + } + } + + return $container; + } + + <<__Override>> + public function addRecursiveExpression(token_list $tokens, int $pointer, bool $negated = false): int { + // this might not end up as a binary expression, but it is ok for it to start that way! + // $right could be empty here if we encountered an expression on the right hand side of an operator like, "column_name = CASE.." + $tmp = $this->right ? 
new BinaryOperatorExpression($this->right) : new PlaceholderExpression(); + + // what we want to do is tell the child to process itself until it finds a precedence lower than the parent + $p = new ExpressionParser($tokens, $pointer, $tmp, $this->precedence, true); + list($pointer, $new_expression) = $p->buildWithPointer(); + + if ($negated) { + $new_expression->negate(); + } + + $this->setNextChild($new_expression, true); + + return $pointer; + } } diff --git a/src/Expressions/CaseOperatorExpression.php b/src/Expressions/CaseOperatorExpression.php index ba6aaef..9abd21e 100644 --- a/src/Expressions/CaseOperatorExpression.php +++ b/src/Expressions/CaseOperatorExpression.php @@ -9,147 +9,147 @@ */ final class CaseOperatorExpression extends Expression { - private vec Expression, - 'then' => Expression, - )> $whenExpressions = vec[]; - - private ?Expression $when; - private ?Expression $then; - private ?Expression $else; - private string $lastKeyword = 'CASE'; - private bool $wellFormed = false; - - public function __construct(token $_token) { - $op = Operator::CASE; - $this->name = operator_to_string($op); - $this->precedence = ExpressionParser::OPERATOR_PRECEDENCE[operator_to_string($op)]; - $this->operator = $op; - $this->type = TokenType::OPERATOR; - } - - <<__Override>> - public function evaluateImpl(row $row, AsyncMysqlConnection $conn): mixed { - if (!$this->wellFormed) { - throw new SQLFakeRuntimeException('Attempted to evaluate incomplete CASE expression'); - } - - foreach ($this->whenExpressions as $clause) { - if ((bool)$clause['when']->evaluate($row, $conn)) { - return $clause['then']->evaluate($row, $conn); - } - } - - invariant($this->else is nonnull, 'must have else since wellFormed was true'); - return $this->else->evaluate($row, $conn); - } - - <<__Override>> - public function isWellFormed(): bool { - return $this->wellFormed; - } - - public function setKeyword(string $keyword): void { - switch ($keyword) { - case 'WHEN': - if ($this->lastKeyword !== 
'CASE' && $this->lastKeyword !== 'THEN') { - throw new SQLFakeParseException('Unexpected WHEN in CASE statement'); - } - $this->lastKeyword = 'WHEN'; - // set these to null in case this is not the first WHEN clause, so that the clauses know to accept expressions - $this->when = null; - $this->then = null; - break; - case 'THEN': - if ($this->lastKeyword !== 'WHEN' || !$this->when) { - throw new SQLFakeParseException('Unexpected THEN in CASE statement'); - } - $this->lastKeyword = 'THEN'; - break; - case 'ELSE': - if ($this->lastKeyword !== 'THEN' || !$this->then) { - throw new SQLFakeParseException('Unexpected ELSE in CASE statement'); - } - $this->lastKeyword = 'ELSE'; - break; - case 'END': - // ELSE clause is optional, it becomes an "ELSE NULL" implicitly if not present - if ($this->lastKeyword === 'THEN' && $this->then) { - $this->else = new ConstantExpression(shape( - 'type' => TokenType::NULL_CONSTANT, - 'value' => 'null', - 'raw' => 'null', - )); - } else if ($this->lastKeyword !== 'ELSE' || !$this->else) { - throw new SQLFakeParseException('Unexpected END in CASE statement'); - } - $this->lastKeyword = 'END'; - $this->wellFormed = true; - break; - default: - throw new SQLFakeParseException("Unexpected keyword $keyword in CASE statement"); - } - } - - <<__Override>> - public function setNextChild(Expression $expr, bool $overwrite = false): void { - switch ($this->lastKeyword) { - case 'CASE': - throw new SQLFakeParseException('Missing WHEN in CASE'); - case 'WHEN': - if ($this->when && !$overwrite) { - throw new SQLFakeParseException('Unexpected token near WHEN'); - } - $this->when = $expr; - break; - case 'THEN': - if ($this->then && !$overwrite) { - throw new SQLFakeParseException('Unexpected token near THEN'); - } - $this->then = $expr; - $this->whenExpressions[] = shape('when' => $this->when as nonnull, 'then' => $expr); - break; - case 'ELSE': - if ($this->else && !$overwrite) { - throw new SQLFakeParseException('Unexpected token near ELSE'); - } - 
$this->else = $expr; - break; - case 'END': - throw new SQLFakeParseException('Unexpected token near END'); - } - } - - <<__Override>> - public function addRecursiveExpression(token_list $tokens, int $pointer, bool $negated = false): int { - $p = new ExpressionParser($tokens, $pointer, new PlaceholderExpression(), 0, true); - list($pointer, $new_expression) = $p->buildWithPointer(); - - if ($negated) { - $new_expression->negate(); - } - - // the way case statements are parsed... we actually do not want to overwrite - $this->setNextChild($new_expression, false); - return $pointer; - } - - <<__Override>> - public function __debugInfo(): dict { - invariant(!C\is_empty($this->whenExpressions), 'There must be at least one whenExpression'); - $when_list = vec[]; - foreach ($this->whenExpressions as $exp) { - $when_list[] = dict['when' => \var_dump($exp['when'], true), 'then' => \var_dump($exp['then'], true)]; - } - $ret = dict[ - 'type' => (string)$this->type, - 'whenExpressions' => $when_list, - 'else' => $this->else ? 
\var_dump($this->else, true) : dict[], - ]; - - if (!Str\is_empty($this->name)) { - $ret['name'] = $this->name; - } - return $ret; - } + private vec Expression, + 'then' => Expression, + )> $whenExpressions = vec[]; + + private ?Expression $when; + private ?Expression $then; + private ?Expression $else; + private string $lastKeyword = 'CASE'; + private bool $wellFormed = false; + + public function __construct(token $_token) { + $op = Operator::CASE; + $this->name = operator_to_string($op); + $this->precedence = ExpressionParser::OPERATOR_PRECEDENCE[operator_to_string($op)]; + $this->operator = $op; + $this->type = TokenType::OPERATOR; + } + + <<__Override>> + public function evaluateImpl(row $row, AsyncMysqlConnection $conn): mixed { + if (!$this->wellFormed) { + throw new SQLFakeRuntimeException('Attempted to evaluate incomplete CASE expression'); + } + + foreach ($this->whenExpressions as $clause) { + if ((bool)$clause['when']->evaluate($row, $conn)) { + return $clause['then']->evaluate($row, $conn); + } + } + + invariant($this->else is nonnull, 'must have else since wellFormed was true'); + return $this->else->evaluate($row, $conn); + } + + <<__Override>> + public function isWellFormed(): bool { + return $this->wellFormed; + } + + public function setKeyword(string $keyword): void { + switch ($keyword) { + case 'WHEN': + if ($this->lastKeyword !== 'CASE' && $this->lastKeyword !== 'THEN') { + throw new SQLFakeParseException('Unexpected WHEN in CASE statement'); + } + $this->lastKeyword = 'WHEN'; + // set these to null in case this is not the first WHEN clause, so that the clauses know to accept expressions + $this->when = null; + $this->then = null; + break; + case 'THEN': + if ($this->lastKeyword !== 'WHEN' || !$this->when) { + throw new SQLFakeParseException('Unexpected THEN in CASE statement'); + } + $this->lastKeyword = 'THEN'; + break; + case 'ELSE': + if ($this->lastKeyword !== 'THEN' || !$this->then) { + throw new SQLFakeParseException('Unexpected ELSE in 
CASE statement'); + } + $this->lastKeyword = 'ELSE'; + break; + case 'END': + // ELSE clause is optional, it becomes an "ELSE NULL" implicitly if not present + if ($this->lastKeyword === 'THEN' && $this->then) { + $this->else = new ConstantExpression(shape( + 'type' => TokenType::NULL_CONSTANT, + 'value' => 'null', + 'raw' => 'null', + )); + } else if ($this->lastKeyword !== 'ELSE' || !$this->else) { + throw new SQLFakeParseException('Unexpected END in CASE statement'); + } + $this->lastKeyword = 'END'; + $this->wellFormed = true; + break; + default: + throw new SQLFakeParseException("Unexpected keyword $keyword in CASE statement"); + } + } + + <<__Override>> + public function setNextChild(Expression $expr, bool $overwrite = false): void { + switch ($this->lastKeyword) { + case 'CASE': + throw new SQLFakeParseException('Missing WHEN in CASE'); + case 'WHEN': + if ($this->when && !$overwrite) { + throw new SQLFakeParseException('Unexpected token near WHEN'); + } + $this->when = $expr; + break; + case 'THEN': + if ($this->then && !$overwrite) { + throw new SQLFakeParseException('Unexpected token near THEN'); + } + $this->then = $expr; + $this->whenExpressions[] = shape('when' => $this->when as nonnull, 'then' => $expr); + break; + case 'ELSE': + if ($this->else && !$overwrite) { + throw new SQLFakeParseException('Unexpected token near ELSE'); + } + $this->else = $expr; + break; + case 'END': + throw new SQLFakeParseException('Unexpected token near END'); + } + } + + <<__Override>> + public function addRecursiveExpression(token_list $tokens, int $pointer, bool $negated = false): int { + $p = new ExpressionParser($tokens, $pointer, new PlaceholderExpression(), 0, true); + list($pointer, $new_expression) = $p->buildWithPointer(); + + if ($negated) { + $new_expression->negate(); + } + + // the way case statements are parsed... 
we actually do not want to overwrite + $this->setNextChild($new_expression, false); + return $pointer; + } + + <<__Override>> + public function __debugInfo(): dict { + invariant(!C\is_empty($this->whenExpressions), 'There must be at least one whenExpression'); + $when_list = vec[]; + foreach ($this->whenExpressions as $exp) { + $when_list[] = dict['when' => \var_dump($exp['when'], true), 'then' => \var_dump($exp['then'], true)]; + } + $ret = dict[ + 'type' => (string)$this->type, + 'whenExpressions' => $when_list, + 'else' => $this->else ? \var_dump($this->else, true) : dict[], + ]; + + if (!Str\is_empty($this->name)) { + $ret['name'] = $this->name; + } + return $ret; + } } diff --git a/src/Expressions/ColumnExpression.php b/src/Expressions/ColumnExpression.php index 0bf34f3..7e8f6a6 100644 --- a/src/Expressions/ColumnExpression.php +++ b/src/Expressions/ColumnExpression.php @@ -6,102 +6,102 @@ // extracts a column from a row by index final class ColumnExpression extends Expression { - private string $columnExpression; - private string $columnName; - private ?string $tableName; - private ?string $databaseName; - private bool $allowFallthrough = false; - - public function __construct(token $token) { - $this->type = $token['type']; - $this->precedence = 0; - - $this->columnExpression = $token['value']; - $this->columnName = $token['value']; - - // TODO handle database schema here - if (Str\contains($token['value'], '.')) { - $parts = Str\split($token['value'], '.'); - if (C\count($parts) === 2) { - list($this->tableName, $this->columnName) = $parts; - } else if (C\count($parts) === 3) { - list($this->databaseName, $this->tableName, $this->columnName) = $parts; - } - } else { - $this->tableName = null; - } - - if ($token['value'] === '*') { - $this->name = '*'; - return; - } - - $this->name = $this->columnName; - } - - <<__Override>> - public function evaluateImpl(row $row, AsyncMysqlConnection $_conn): mixed { - // for the "COUNT(*)" case, just return 1 - // we don't 
actually implement "*" in this library, the select processer handles that - if ($this->name === '*') { - return 1; - } - - $row = $this->maybeUnrollGroupedDataset($row); - - // otherwise return the column - if (C\contains_key($row, $this->columnExpression)) { - return $row[$this->columnExpression]; - } else if (($this->tableName === null && $this->columnName is nonnull) || $this->allowFallthrough) { - // didn't find row by alias, so search without alias instead - // but only if the column expression didn't have an explicit table name on it - // OR if we are explicitly allowing fallthrough to the full row, which we do in the ORDER BY clause - $dot_column_name = '.'.$this->columnName; - foreach ($row as $key => $col) { - if (Str\ends_with($key, $dot_column_name)) { - return $col; - } - } - } - - if (C\contains_key($row, $this->name)) { - return $row[$this->name]; - } - - if (QueryContext::$strictSchemaMode) { - // we've running in strict mode but we still ran into a column that was missing. - // this means we're selecting on a column that does not exist - throw new SQLFakeRuntimeException('Column with index '.$this->columnExpression.' not found in row'); - } else { - return null; - } - } - - /** - * for use in ORDER BY... allow evaluating the expression - * to fall through to the full row if the column is not found fully qualified. 
- */ - public function allowFallthrough(): void { - $this->allowFallthrough = true; - } - - <<__Override>> - public function isWellFormed(): bool { - return true; - } - - public function tableName(): ?string { - return $this->tableName; - } - - public function prefixColumnExpression(string $prefix): void { - if (!Str\starts_with($this->columnExpression, $prefix)) { - $this->columnExpression = $prefix.$this->columnExpression; - } - } - - <<__Override>> - public function __debugInfo(): dict { - return dict['type' => 'colref', 'name' => $this->name]; - } + private string $columnExpression; + private string $columnName; + private ?string $tableName; + private ?string $databaseName; + private bool $allowFallthrough = false; + + public function __construct(token $token) { + $this->type = $token['type']; + $this->precedence = 0; + + $this->columnExpression = $token['value']; + $this->columnName = $token['value']; + + // TODO handle database schema here + if (Str\contains($token['value'], '.')) { + $parts = Str\split($token['value'], '.'); + if (C\count($parts) === 2) { + list($this->tableName, $this->columnName) = $parts; + } else if (C\count($parts) === 3) { + list($this->databaseName, $this->tableName, $this->columnName) = $parts; + } + } else { + $this->tableName = null; + } + + if ($token['value'] === '*') { + $this->name = '*'; + return; + } + + $this->name = $this->columnName; + } + + <<__Override>> + public function evaluateImpl(row $row, AsyncMysqlConnection $_conn): mixed { + // for the "COUNT(*)" case, just return 1 + // we don't actually implement "*" in this library, the select processer handles that + if ($this->name === '*') { + return 1; + } + + $row = $this->maybeUnrollGroupedDataset($row); + + // otherwise return the column + if (C\contains_key($row, $this->columnExpression)) { + return $row[$this->columnExpression]; + } else if (($this->tableName === null && $this->columnName is nonnull) || $this->allowFallthrough) { + // didn't find row by alias, so 
search without alias instead + // but only if the column expression didn't have an explicit table name on it + // OR if we are explicitly allowing fallthrough to the full row, which we do in the ORDER BY clause + $dot_column_name = '.'.$this->columnName; + foreach ($row as $key => $col) { + if (Str\ends_with($key, $dot_column_name)) { + return $col; + } + } + } + + if (C\contains_key($row, $this->name)) { + return $row[$this->name]; + } + + if (QueryContext::$strictSchemaMode) { + // we've running in strict mode but we still ran into a column that was missing. + // this means we're selecting on a column that does not exist + throw new SQLFakeRuntimeException('Column with index '.$this->columnExpression.' not found in row'); + } else { + return null; + } + } + + /** + * for use in ORDER BY... allow evaluating the expression + * to fall through to the full row if the column is not found fully qualified. + */ + public function allowFallthrough(): void { + $this->allowFallthrough = true; + } + + <<__Override>> + public function isWellFormed(): bool { + return true; + } + + public function tableName(): ?string { + return $this->tableName; + } + + public function prefixColumnExpression(string $prefix): void { + if (!Str\starts_with($this->columnExpression, $prefix)) { + $this->columnExpression = $prefix.$this->columnExpression; + } + } + + <<__Override>> + public function __debugInfo(): dict { + return dict['type' => 'colref', 'name' => $this->name]; + } } diff --git a/src/Expressions/ConstantExpression.php b/src/Expressions/ConstantExpression.php index d9df7e6..a68c7e6 100644 --- a/src/Expressions/ConstantExpression.php +++ b/src/Expressions/ConstantExpression.php @@ -9,47 +9,46 @@ */ final class ConstantExpression extends Expression { - public mixed $value; - - public function __construct(token $token) { - $this->type = $token['type']; - $this->precedence = 0; - $this->name = $token['value']; - $this->value = $this->extractConstantValue($token); - } - - private 
function extractConstantValue(token $token): mixed { - switch ($token['type']) { - case TokenType::NUMERIC_CONSTANT: - if (Str\contains((string)$token['value'], '.')) { - return (float)$token['value']; - } - return (int)$token['value']; - case TokenType::BOOLEAN_CONSTANT: - return (bool)$token['value']; - case TokenType::STRING_CONSTANT: - return (string)$token['value']; - case TokenType::NULL_CONSTANT: - return null; - default: - throw new SQLFakeRuntimeException( - "Attempted to assign invalid token type {$token['type']} to Constant Expression", - ); - } - } - - <<__Override>> - public function evaluateImpl(row $_row, AsyncMysqlConnection $_conn): mixed { - return $this->value; - } - - <<__Override>> - public function isWellFormed(): bool { - return true; - } - - <<__Override>> - public function __debugInfo(): dict { - return dict['type' => 'const', 'name' => $this->name, 'value' => $this->value]; - } + public mixed $value; + + public function __construct(token $token) { + $this->type = $token['type']; + $this->precedence = 0; + $this->name = $token['value']; + $this->value = $this->extractConstantValue($token); + } + + private function extractConstantValue(token $token): mixed { + switch ($token['type']) { + case TokenType::NUMERIC_CONSTANT: + if (Str\contains((string)$token['value'], '.')) { + return (float)$token['value']; + } + return (int)$token['value']; + case TokenType::BOOLEAN_CONSTANT: + return (bool)$token['value']; + case TokenType::STRING_CONSTANT: + return (string)$token['value']; + case TokenType::NULL_CONSTANT: + return null; + default: + throw + new SQLFakeRuntimeException("Attempted to assign invalid token type {$token['type']} to Constant Expression"); + } + } + + <<__Override>> + public function evaluateImpl(row $_row, AsyncMysqlConnection $_conn): mixed { + return $this->value; + } + + <<__Override>> + public function isWellFormed(): bool { + return true; + } + + <<__Override>> + public function __debugInfo(): dict { + return dict['type' => 
'const', 'name' => $this->name, 'value' => $this->value]; + } } diff --git a/src/Expressions/Expression.php b/src/Expressions/Expression.php index 3ffda25..378f360 100644 --- a/src/Expressions/Expression.php +++ b/src/Expressions/Expression.php @@ -3,8 +3,8 @@ namespace Slack\SQLFake; type ExpressionEvaluationOpts = shape( - ?'encode_json' => bool, - ?'bool_as_int' => bool, + ?'encode_json' => bool, + ?'bool_as_int' => bool, ); /* @@ -25,109 +25,109 @@ */ abstract class Expression { - public ?Operator $operator; - public bool $negated = false; - public int $precedence; - public string $name; - protected TokenType $type; - protected bool $evaluates_groups = false; - - /* - * many expressions won't support negation, - * and should throw a parse error if this is called - * subclasses that do support negation must override this - */ - public function negate(): void { - throw new SQLFakeParseException("Parse error: unexpected NOT for expression {$this->type}"); - } - - /** - * Expressions are built up incrementally when parsing - * This function allows an expression to signify if it has all of the required sub-expressions, - * such as having both the "left" and "right" operators for a binary expressions - */ - public abstract function isWellFormed(): bool; - - // This is not the method to override by the concrete Expression subclasses () - final public function evaluate( - row $row, - AsyncMysqlConnection $conn, - ExpressionEvaluationOpts $opts = shape('encode_json' => true, 'bool_as_int' => true), - ): mixed { - $encodeJSON = $opts['encode_json'] ?? true; - $boolAsInt = $opts['bool_as_int'] ?? true; - - $result = $this->evaluateImpl($row, $conn); - - if ($result is WrappedJSON) { - $result = $encodeJSON ? $result->asString() : $result; - } - - if ($result is bool) { - return $boolAsInt ? 
(int)$result : $result; - } - - return $result; - } - - /** - * a lot of times you just want the value - */ - public abstract function evaluateImpl(row $row, AsyncMysqlConnection $conn): mixed; - - /** - * when evaluating an expression in a Select, we want its name to use as the column name - */ - public function evaluateWithName(row $row, AsyncMysqlConnection $conn): (string, mixed) { - return tuple($this->name, $this->evaluate($row, $conn)); - } - - public function getType(): TokenType { - return $this->type; - } - - /** - * Only some expression types support children - * For example, set "right" on a binary, or "start" on a BETWEEN, or "when" on a case - * For several of the types that don't have child elements this is just a parse error, children who implement it can override - */ - public function setNextChild(Expression $_expr, bool $_overwrite = false): void { - throw new SQLFakeParseException('Parse error: unexpected expression'); - } - - /** - * Only some expression types support recursive expressions - * otherwise if unimplemented, it's a parse error - */ - public function addRecursiveExpression(token_list $_tokens, int $_pointer, bool $_negated = false): int { - throw new SQLFakeParseException('Parse error: unexpected recursive expression'); - } - - /** - * All operators have to handle potentially grouped data sets. 
- * row is a dict, but for a grouped data set the "mixed" - * will itself be a dict, so the row will be a dict> - * See applyGroupBy which does the grouping - * Since some operators don't want grouped data (column expressions and non-aggregate functions) - * this helper lets them extract the first value from the grouping set - */ - protected function maybeUnrollGroupedDataset(row $rows): row { - // as an optimization, we deliberately don't call C\first - foreach ($rows as $row) { - if ($row is dict<_, _>) { - /* HH_FIXME[4110] generics can't be specified here yet */ - return $row; - } - break; - } - - return $rows; - } - - /** - * Return a container-ish representation of an expression for pretty printing - * used for logging and debugging. - * Expressions with children SHOULD call this on all children too - */ - public abstract function __debugInfo(): KeyedContainer; + public ?Operator $operator; + public bool $negated = false; + public int $precedence; + public string $name; + protected TokenType $type; + protected bool $evaluates_groups = false; + + /* + * many expressions won't support negation, + * and should throw a parse error if this is called + * subclasses that do support negation must override this + */ + public function negate(): void { + throw new SQLFakeParseException("Parse error: unexpected NOT for expression {$this->type}"); + } + + /** + * Expressions are built up incrementally when parsing + * This function allows an expression to signify if it has all of the required sub-expressions, + * such as having both the "left" and "right" operators for a binary expressions + */ + public abstract function isWellFormed(): bool; + + // This is not the method to override by the concrete Expression subclasses () + final public function evaluate( + row $row, + AsyncMysqlConnection $conn, + ExpressionEvaluationOpts $opts = shape('encode_json' => true, 'bool_as_int' => true), + ): mixed { + $encodeJSON = $opts['encode_json'] ?? 
true; + $boolAsInt = $opts['bool_as_int'] ?? true; + + $result = $this->evaluateImpl($row, $conn); + + if ($result is WrappedJSON) { + $result = $encodeJSON ? $result->asString() : $result; + } + + if ($result is bool) { + return $boolAsInt ? (int)$result : $result; + } + + return $result; + } + + /** + * a lot of times you just want the value + */ + public abstract function evaluateImpl(row $row, AsyncMysqlConnection $conn): mixed; + + /** + * when evaluating an expression in a Select, we want its name to use as the column name + */ + public function evaluateWithName(row $row, AsyncMysqlConnection $conn): (string, mixed) { + return tuple($this->name, $this->evaluate($row, $conn)); + } + + public function getType(): TokenType { + return $this->type; + } + + /** + * Only some expression types support children + * For example, set "right" on a binary, or "start" on a BETWEEN, or "when" on a case + * For several of the types that don't have child elements this is just a parse error, children who implement it can override + */ + public function setNextChild(Expression $_expr, bool $_overwrite = false): void { + throw new SQLFakeParseException('Parse error: unexpected expression'); + } + + /** + * Only some expression types support recursive expressions + * otherwise if unimplemented, it's a parse error + */ + public function addRecursiveExpression(token_list $_tokens, int $_pointer, bool $_negated = false): int { + throw new SQLFakeParseException('Parse error: unexpected recursive expression'); + } + + /** + * All operators have to handle potentially grouped data sets. 
+ * row is a dict, but for a grouped data set the "mixed" + * will itself be a dict, so the row will be a dict> + * See applyGroupBy which does the grouping + * Since some operators don't want grouped data (column expressions and non-aggregate functions) + * this helper lets them extract the first value from the grouping set + */ + protected function maybeUnrollGroupedDataset(row $rows): row { + // as an optimization, we deliberately don't call C\first + foreach ($rows as $row) { + if ($row is dict<_, _>) { + /* HH_FIXME[4110] generics can't be specified here yet */ + return $row; + } + break; + } + + return $rows; + } + + /** + * Return a container-ish representation of an expression for pretty printing + * used for logging and debugging. + * Expressions with children SHOULD call this on all children too + */ + public abstract function __debugInfo(): KeyedContainer; } diff --git a/src/Expressions/FunctionExpression.php b/src/Expressions/FunctionExpression.php index 211e8a9..9cb4fb9 100644 --- a/src/Expressions/FunctionExpression.php +++ b/src/Expressions/FunctionExpression.php @@ -10,452 +10,452 @@ */ final class FunctionExpression extends BaseFunctionExpression { - <<__Override>> - public function evaluateImpl(row $row, AsyncMysqlConnection $conn): mixed { - - switch ($this->functionName) { - case 'COUNT': - return $this->sqlCount($row, $conn); - case 'SUM': - return $this->sqlSum($row, $conn); - case 'MAX': - return $this->sqlMax($row, $conn); - case 'MIN': - return $this->sqlMin($row, $conn); - case 'MOD': - return $this->sqlMod($row, $conn); - case 'AVG': - return $this->sqlAvg($row, $conn); - case 'IF': - return $this->sqlIf($row, $conn); - case 'IFNULL': - case 'COALESCE': - return $this->sqlCoalesce($row, $conn); - case 'NULLIF': - return $this->sqlNullif($row, $conn); - case 'SUBSTRING': - case 'SUBSTR': - return $this->sqlSubstring($row, $conn); - case 'SUBSTRING_INDEX': - return $this->sqlSubstringIndex($row, $conn); - case 'LENGTH': - return 
$this->sqlLength($row, $conn); - case 'LOWER': - return $this->sqlLower($row, $conn); - case 'CHAR_LENGTH': - case 'CHARACTER_LENGTH': - return $this->sqlCharLength($row, $conn); - case 'CONCAT_WS': - return $this->sqlConcatWS($row, $conn); - case 'CONCAT': - return $this->sqlConcat($row, $conn); - case 'FIELD': - return $this->sqlField($row, $conn); - case 'BINARY': - return $this->sqlBinary($row, $conn); - case 'FROM_UNIXTIME': - return $this->sqlFromUnixtime($row, $conn); - case 'GREATEST': - return $this->sqlGreatest($row, $conn); - case 'VALUES': - return $this->sqlValues($row, $conn); - case 'REPLACE': - return $this->sqlReplace($row, $conn); - case 'UNIX_TIMESTAMP': - return $this->sqlUnixTimestamp($row, $conn); - default: - throw new SQLFakeRuntimeException('Function '.$this->functionName.' not implemented yet'); - } - } - - public function isAggregate(): bool { - return C\contains_key(keyset['COUNT', 'SUM', 'MIN', 'MAX', 'AVG'], $this->functionName); - } - - private function sqlCount(row $rows, AsyncMysqlConnection $conn): int { - $expr = $this->getExpr(); - - if ($this->distinct) { - $buckets = dict[]; - foreach ($rows as $row) { - $row as dict<_, _>; - $val = $expr->evaluate(/* HH_FIXME[4110] generics */ $row, $conn); - if ($val is arraykey) { - $buckets[$val] = 1; - } - } - - return C\count($buckets); - } - - $count = 0; - foreach ($rows as $row) { - // all functions are passed a row object - // but select process will pass groups of rows instead so each element should be an entire row - $row as dict<_, _>; - if ($expr->evaluate(/* HH_FIXME[4110] generics */ $row, $conn) is nonnull) { - $count++; - } - } - - return $count; - } - - private function sqlSum(row $rows, AsyncMysqlConnection $conn): num { - $expr = $this->getExpr(); - $sum = 0; - - foreach ($rows as $row) { - $row as dict<_, _>; - $val = $expr->evaluate(/* HH_FIXME[4110] generics */ $row, $conn); - $num = $val is int ? 
$val : (float)($val); - $sum += $num; - } - return $sum; - } - - private function sqlMin(row $rows, AsyncMysqlConnection $conn): mixed { - $expr = $this->getExpr(); - $values = vec[]; - foreach ($rows as $row) { - $row as dict<_, _>; - $values[] = $expr->evaluate(/* HH_FIXME[4110] generics */ $row, $conn); - } - - if (C\is_empty($values)) { - return null; - } - - return \min($values); - } - - private function sqlMax(row $rows, AsyncMysqlConnection $conn): mixed { - $expr = $this->getExpr(); - - $values = vec[]; - foreach ($rows as $row) { - $row as dict<_, _>; - $values[] = $expr->evaluate(/* HH_FIXME[4110] generics */ $row, $conn); - } - - if (C\is_empty($values)) { - return null; - } - - return \max($values); - } - - private function sqlMod(row $row, AsyncMysqlConnection $conn): mixed { - $row = $this->maybeUnrollGroupedDataset($row); - - $args = $this->args; - if (C\count($args) !== 2) { - throw new SQLFakeRuntimeException('MySQL MOD() function must be called with two arguments'); - } - $n = $args[0]; - $n_value = (int)$n->evaluate($row, $conn); - $m = $args[1]; - $m_value = (int)$m->evaluate($row, $conn); - - return $n_value % $m_value; - } - - private function sqlAvg(row $rows, AsyncMysqlConnection $conn): mixed { - $expr = $this->getExpr(); - - $values = vec[]; - foreach ($rows as $row) { - $row as dict<_, _>; - $values[] = $expr->evaluate(/* HH_FIXME[4110] generics */ $row, $conn) as num; - } - - if (C\is_empty($values)) { - return null; - } - - return Math\mean($values); - } - - private function sqlIf(row $row, AsyncMysqlConnection $conn): mixed { - $row = $this->maybeUnrollGroupedDataset($row); - - $args = $this->args; - if (C\count($args) !== 3) { - throw new SQLFakeRuntimeException('MySQL IF() function must be called with three arguments'); - } - $condition = $args[0]; - - // evaluate the ELSE condition, unless the IF condition is true - $arg_to_evaluate = 2; - if ((bool)$condition->evaluate($row, $conn)) { - $arg_to_evaluate = 1; - } - $expr = 
$args[$arg_to_evaluate]; - return $expr->evaluate($row, $conn); - } - - private function sqlSubstring(row $row, AsyncMysqlConnection $conn): mixed { - $row = $this->maybeUnrollGroupedDataset($row); - - $args = $this->args; - if (C\count($args) !== 2 && C\count($args) !== 3) { - throw new SQLFakeRuntimeException('MySQL SUBSTRING() function must be called with two or three arguments'); - } - $subject = $args[0]; - $string = (string)$subject->evaluate($row, $conn); - - $position = $args[1]; - $pos = (int)$position->evaluate($row, $conn); - - // MySQL string positions 1-indexed, PHP strings are 0-indexed. So substract one from pos - $pos -= 1; - - $length = $args[2] ?? null; - if ($length !== null) { - $len = (int)$length->evaluate($row, $conn); - return \mb_substr($string, $pos, $len); - } - - return \mb_substr($string, $pos); - } - - private function sqlSubstringIndex(row $row, AsyncMysqlConnection $conn): mixed { - $row = $this->maybeUnrollGroupedDataset($row); - - $args = $this->args; - if (C\count($args) !== 3) { - throw new SQLFakeRuntimeException('MySQL SUBSTRING_INDEX() function must be called with three arguments'); - } - $subject = $args[0]; - $string = (string)$subject->evaluate($row, $conn); - - $delimiter = $args[1]; - $delim = (string)$delimiter->evaluate($row, $conn); - - // MySQL string positions 1-indexed, PHP strings are 0-indexed. 
So substract one from pos - $pos = $args[2]; - if ($pos is nonnull) { - $count = (int)$pos->evaluate($row, $conn); - $parts = Str\split($string, $delim); - if ($count < 0) { - $slice = \array_slice($parts, $count); - } else { - $slice = \array_slice($parts, 0, $count); - } - - return Str\join($slice, $delim); - } - return ''; - } - - private function sqlLower(row $row, AsyncMysqlConnection $conn): mixed { - $row = $this->maybeUnrollGroupedDataset($row); - - $args = $this->args; - if (C\count($args) !== 1) { - throw new SQLFakeRuntimeException('MySQL LOWER() function must be called with one argument'); - } - $subject = $args[0]; - $string = (string)$subject->evaluate($row, $conn); - - return Str\lowercase($string); - } - - private function sqlLength(row $row, AsyncMysqlConnection $conn): mixed { - $row = $this->maybeUnrollGroupedDataset($row); - - $args = $this->args; - if (C\count($args) !== 1) { - throw new SQLFakeRuntimeException('MySQL LENGTH() function must be called with one argument'); - } - $subject = $args[0]; - $string = (string)$subject->evaluate($row, $conn); - - return Str\length($string); - } - - private function sqlBinary(row $row, AsyncMysqlConnection $conn): mixed { - $row = $this->maybeUnrollGroupedDataset($row); - - $args = $this->args; - if (C\count($args) !== 1) { - throw new SQLFakeRuntimeException('MySQL BINARY() function must be called with one argument'); - } - $subject = $args[0]; - return $subject->evaluate($row, $conn); - } - - private function sqlCharLength(row $row, AsyncMysqlConnection $conn): mixed { - $row = $this->maybeUnrollGroupedDataset($row); - - $args = $this->args; - if (C\count($args) !== 1) { - throw new SQLFakeRuntimeException('MySQL CHAR_LENGTH() function must be called with one argument'); - } - $subject = $args[0]; - $string = (string)$subject->evaluate($row, $conn); - - return \mb_strlen($string); - } - - private function sqlCoalesce(row $row, AsyncMysqlConnection $conn): mixed { - $row = 
$this->maybeUnrollGroupedDataset($row); - - if (!C\count($this->args)) { - throw new SQLFakeRuntimeException('MySQL COALESCE() function must be called with at least one argument'); - } - - foreach ($this->args as $arg) { - $val = $arg->evaluate($row, $conn); - if ($val is nonnull) { - return $val; - } - } - return null; - } - - private function sqlGreatest(row $row, AsyncMysqlConnection $conn): mixed { - $row = $this->maybeUnrollGroupedDataset($row); - - $args = $this->args; - - if (C\count($args) < 2) { - throw new SQLFakeRuntimeException('MySQL GREATEST() function must be called with at two arguments'); - } - - $values = vec[]; - foreach ($this->args as $arg) { - $val = $arg->evaluate($row, $conn); - // MySQL always returns null if ANY argument to this function is null - if ($val is null) { - return null; - } - $values[] = $val; - } - - return \max($values); - } - - private function sqlNullif(row $row, AsyncMysqlConnection $conn): mixed { - $row = $this->maybeUnrollGroupedDataset($row); - - $args = $this->args; - if (C\count($args) !== 2) { - throw new SQLFakeRuntimeException('MySQL NULLIF() function must be called with two arguments'); - } - $left = $args[0]->evaluate($row, $conn); - $right = $args[1]->evaluate($row, $conn); - - return ($left === $right) ? 
null : $left; - } - - private function sqlFromUnixtime(row $row, AsyncMysqlConnection $conn): string { - $row = $this->maybeUnrollGroupedDataset($row); - - $args = $this->args; - if (C\count($args) !== 1) { - throw new SQLFakeRuntimeException('MySQL FROM_UNIXTIME() SQLFake only implemented for 1 argument'); - } - - $column = $args[0]->evaluate($row, $conn); - - // - // This is the default format from MySQL ‘YYYYY-MM-DD HH:MM:SS’ - // - - $format = 'Y-m-d G:i:s'; - return \date($format, (int)$column); - } - - private function sqlConcat(row $row, AsyncMysqlConnection $conn): string { - $row = $this->maybeUnrollGroupedDataset($row); - $args = $this->args; - if (C\count($args) < 2) { - throw new SQLFakeRuntimeException('MySQL CONCAT() function must be called with at least two arguments'); - } - $final_concat = ''; - foreach ($args as $arg) { - $val = (string)$arg->evaluate($row, $conn); - $final_concat .= $val; - } - return $final_concat; - } - - private function sqlConcatWS(row $row, AsyncMysqlConnection $conn): string { - $row = $this->maybeUnrollGroupedDataset($row); - $args = $this->args; - if (C\count($args) < 2) { - throw new SQLFakeRuntimeException('MySQL CONCAT_WS() function must be called with at least two arguments'); - } - $separator = $args[0]->evaluate($row, $conn); - if ($separator === null) { - throw new SQLFakeRuntimeException('MySQL CONCAT_WS() function required non null separator'); - } - $separator = (string)$separator; - $final_concat = ''; - foreach ($args as $k => $arg) { - if ($k < 1) { - continue; - } - $val = (string)$arg->evaluate($row, $conn); - if (Str\is_empty($final_concat)) { - $final_concat = $final_concat.$val; - } else { - $final_concat = $final_concat.$separator.$val; - } - } - return $final_concat; - } - - private function sqlField(row $row, AsyncMysqlConnection $conn): mixed { - $args = $this->args; - $num_args = C\count($args); - if ($num_args < 2) { - throw new SQLFakeRuntimeException('MySQL FIELD() function must be called with at 
least two arguments'); - } - - $value = $args[0]->evaluate($row, $conn); - foreach ($args as $k => $arg) { - if ($k < 1) { - continue; - } - if ($value == $arg->evaluate($row, $conn)) { - return $k; - } - } - return 0; - } - - private function sqlValues(row $row, AsyncMysqlConnection $conn): mixed { - $args = $this->args; - $num_args = C\count($args); - if ($num_args !== 1) { - throw new SQLFakeRuntimeException('MySQL VALUES() function must be called with one argument'); - } - - $arg = $args[0]; - if (!$arg is ColumnExpression) { - throw new SQLFakeRuntimeException('MySQL VALUES() function should be called with a column name'); - } - - // a bit hacky here, override so that the expression pulls the value from the sql_fake_values.* fields set in Query::applySet - $arg->prefixColumnExpression('sql_fake_values.'); - return $arg->evaluate($row, $conn); - } - - private function sqlReplace(row $row, AsyncMysqlConnection $conn): string { - $row = $this->maybeUnrollGroupedDataset($row); - $args = $this->args; - if (C\count($args) !== 3) { - throw new SQLFakeRuntimeException('MySQL REPLACE() function must be called with three arguments'); - } - return Str\replace_every( - (string)$args[0]->evaluate($row, $conn), - dict[(string)$args[1]->evaluate($row, $conn) => (string)$args[2]->evaluate($row, $conn)], - ); - } - - private function sqlUnixTimestamp(row $row, AsyncMysqlConnection $conn): int { + <<__Override>> + public function evaluateImpl(row $row, AsyncMysqlConnection $conn): mixed { + + switch ($this->functionName) { + case 'COUNT': + return $this->sqlCount($row, $conn); + case 'SUM': + return $this->sqlSum($row, $conn); + case 'MAX': + return $this->sqlMax($row, $conn); + case 'MIN': + return $this->sqlMin($row, $conn); + case 'MOD': + return $this->sqlMod($row, $conn); + case 'AVG': + return $this->sqlAvg($row, $conn); + case 'IF': + return $this->sqlIf($row, $conn); + case 'IFNULL': + case 'COALESCE': + return $this->sqlCoalesce($row, $conn); + case 'NULLIF': + return 
$this->sqlNullif($row, $conn); + case 'SUBSTRING': + case 'SUBSTR': + return $this->sqlSubstring($row, $conn); + case 'SUBSTRING_INDEX': + return $this->sqlSubstringIndex($row, $conn); + case 'LENGTH': + return $this->sqlLength($row, $conn); + case 'LOWER': + return $this->sqlLower($row, $conn); + case 'CHAR_LENGTH': + case 'CHARACTER_LENGTH': + return $this->sqlCharLength($row, $conn); + case 'CONCAT_WS': + return $this->sqlConcatWS($row, $conn); + case 'CONCAT': + return $this->sqlConcat($row, $conn); + case 'FIELD': + return $this->sqlField($row, $conn); + case 'BINARY': + return $this->sqlBinary($row, $conn); + case 'FROM_UNIXTIME': + return $this->sqlFromUnixtime($row, $conn); + case 'GREATEST': + return $this->sqlGreatest($row, $conn); + case 'VALUES': + return $this->sqlValues($row, $conn); + case 'REPLACE': + return $this->sqlReplace($row, $conn); + case 'UNIX_TIMESTAMP': + return $this->sqlUnixTimestamp($row, $conn); + default: + throw new SQLFakeRuntimeException('Function '.$this->functionName.' 
not implemented yet'); + } + } + + public function isAggregate(): bool { + return C\contains_key(keyset['COUNT', 'SUM', 'MIN', 'MAX', 'AVG'], $this->functionName); + } + + private function sqlCount(row $rows, AsyncMysqlConnection $conn): int { + $expr = $this->getExpr(); + + if ($this->distinct) { + $buckets = dict[]; + foreach ($rows as $row) { + $row as dict<_, _>; + $val = $expr->evaluate(/* HH_FIXME[4110] generics */ $row, $conn); + if ($val is arraykey) { + $buckets[$val] = 1; + } + } + + return C\count($buckets); + } + + $count = 0; + foreach ($rows as $row) { + // all functions are passed a row object + // but select process will pass groups of rows instead so each element should be an entire row + $row as dict<_, _>; + if ($expr->evaluate(/* HH_FIXME[4110] generics */ $row, $conn) is nonnull) { + $count++; + } + } + + return $count; + } + + private function sqlSum(row $rows, AsyncMysqlConnection $conn): num { + $expr = $this->getExpr(); + $sum = 0; + + foreach ($rows as $row) { + $row as dict<_, _>; + $val = $expr->evaluate(/* HH_FIXME[4110] generics */ $row, $conn); + $num = $val is int ? 
$val : (float)($val); + $sum += $num; + } + return $sum; + } + + private function sqlMin(row $rows, AsyncMysqlConnection $conn): mixed { + $expr = $this->getExpr(); + $values = vec[]; + foreach ($rows as $row) { + $row as dict<_, _>; + $values[] = $expr->evaluate(/* HH_FIXME[4110] generics */ $row, $conn); + } + + if (C\is_empty($values)) { + return null; + } + + return \min($values); + } + + private function sqlMax(row $rows, AsyncMysqlConnection $conn): mixed { + $expr = $this->getExpr(); + + $values = vec[]; + foreach ($rows as $row) { + $row as dict<_, _>; + $values[] = $expr->evaluate(/* HH_FIXME[4110] generics */ $row, $conn); + } + + if (C\is_empty($values)) { + return null; + } + + return \max($values); + } + + private function sqlMod(row $row, AsyncMysqlConnection $conn): mixed { + $row = $this->maybeUnrollGroupedDataset($row); + + $args = $this->args; + if (C\count($args) !== 2) { + throw new SQLFakeRuntimeException('MySQL MOD() function must be called with two arguments'); + } + $n = $args[0]; + $n_value = (int)$n->evaluate($row, $conn); + $m = $args[1]; + $m_value = (int)$m->evaluate($row, $conn); + + return $n_value % $m_value; + } + + private function sqlAvg(row $rows, AsyncMysqlConnection $conn): mixed { + $expr = $this->getExpr(); + + $values = vec[]; + foreach ($rows as $row) { + $row as dict<_, _>; + $values[] = $expr->evaluate(/* HH_FIXME[4110] generics */ $row, $conn) as num; + } + + if (C\is_empty($values)) { + return null; + } + + return Math\mean($values); + } + + private function sqlIf(row $row, AsyncMysqlConnection $conn): mixed { + $row = $this->maybeUnrollGroupedDataset($row); + + $args = $this->args; + if (C\count($args) !== 3) { + throw new SQLFakeRuntimeException('MySQL IF() function must be called with three arguments'); + } + $condition = $args[0]; + + // evaluate the ELSE condition, unless the IF condition is true + $arg_to_evaluate = 2; + if ((bool)$condition->evaluate($row, $conn)) { + $arg_to_evaluate = 1; + } + $expr = 
$args[$arg_to_evaluate]; + return $expr->evaluate($row, $conn); + } + + private function sqlSubstring(row $row, AsyncMysqlConnection $conn): mixed { + $row = $this->maybeUnrollGroupedDataset($row); + + $args = $this->args; + if (C\count($args) !== 2 && C\count($args) !== 3) { + throw new SQLFakeRuntimeException('MySQL SUBSTRING() function must be called with two or three arguments'); + } + $subject = $args[0]; + $string = (string)$subject->evaluate($row, $conn); + + $position = $args[1]; + $pos = (int)$position->evaluate($row, $conn); + + // MySQL string positions 1-indexed, PHP strings are 0-indexed. So substract one from pos + $pos -= 1; + + $length = $args[2] ?? null; + if ($length !== null) { + $len = (int)$length->evaluate($row, $conn); + return \mb_substr($string, $pos, $len); + } + + return \mb_substr($string, $pos); + } + + private function sqlSubstringIndex(row $row, AsyncMysqlConnection $conn): mixed { + $row = $this->maybeUnrollGroupedDataset($row); + + $args = $this->args; + if (C\count($args) !== 3) { + throw new SQLFakeRuntimeException('MySQL SUBSTRING_INDEX() function must be called with three arguments'); + } + $subject = $args[0]; + $string = (string)$subject->evaluate($row, $conn); + + $delimiter = $args[1]; + $delim = (string)$delimiter->evaluate($row, $conn); + + // MySQL string positions 1-indexed, PHP strings are 0-indexed. 
So substract one from pos + $pos = $args[2]; + if ($pos is nonnull) { + $count = (int)$pos->evaluate($row, $conn); + $parts = Str\split($string, $delim); + if ($count < 0) { + $slice = \array_slice($parts, $count); + } else { + $slice = \array_slice($parts, 0, $count); + } + + return Str\join($slice, $delim); + } + return ''; + } + + private function sqlLower(row $row, AsyncMysqlConnection $conn): mixed { + $row = $this->maybeUnrollGroupedDataset($row); + + $args = $this->args; + if (C\count($args) !== 1) { + throw new SQLFakeRuntimeException('MySQL LOWER() function must be called with one argument'); + } + $subject = $args[0]; + $string = (string)$subject->evaluate($row, $conn); + + return Str\lowercase($string); + } + + private function sqlLength(row $row, AsyncMysqlConnection $conn): mixed { + $row = $this->maybeUnrollGroupedDataset($row); + + $args = $this->args; + if (C\count($args) !== 1) { + throw new SQLFakeRuntimeException('MySQL LENGTH() function must be called with one argument'); + } + $subject = $args[0]; + $string = (string)$subject->evaluate($row, $conn); + + return Str\length($string); + } + + private function sqlBinary(row $row, AsyncMysqlConnection $conn): mixed { + $row = $this->maybeUnrollGroupedDataset($row); + + $args = $this->args; + if (C\count($args) !== 1) { + throw new SQLFakeRuntimeException('MySQL BINARY() function must be called with one argument'); + } + $subject = $args[0]; + return $subject->evaluate($row, $conn); + } + + private function sqlCharLength(row $row, AsyncMysqlConnection $conn): mixed { + $row = $this->maybeUnrollGroupedDataset($row); + + $args = $this->args; + if (C\count($args) !== 1) { + throw new SQLFakeRuntimeException('MySQL CHAR_LENGTH() function must be called with one argument'); + } + $subject = $args[0]; + $string = (string)$subject->evaluate($row, $conn); + + return \mb_strlen($string); + } + + private function sqlCoalesce(row $row, AsyncMysqlConnection $conn): mixed { + $row = 
$this->maybeUnrollGroupedDataset($row); + + if (!C\count($this->args)) { + throw new SQLFakeRuntimeException('MySQL COALESCE() function must be called with at least one argument'); + } + + foreach ($this->args as $arg) { + $val = $arg->evaluate($row, $conn); + if ($val is nonnull) { + return $val; + } + } + return null; + } + + private function sqlGreatest(row $row, AsyncMysqlConnection $conn): mixed { + $row = $this->maybeUnrollGroupedDataset($row); + + $args = $this->args; + + if (C\count($args) < 2) { + throw new SQLFakeRuntimeException('MySQL GREATEST() function must be called with at two arguments'); + } + + $values = vec[]; + foreach ($this->args as $arg) { + $val = $arg->evaluate($row, $conn); + // MySQL always returns null if ANY argument to this function is null + if ($val is null) { + return null; + } + $values[] = $val; + } + + return \max($values); + } + + private function sqlNullif(row $row, AsyncMysqlConnection $conn): mixed { + $row = $this->maybeUnrollGroupedDataset($row); + + $args = $this->args; + if (C\count($args) !== 2) { + throw new SQLFakeRuntimeException('MySQL NULLIF() function must be called with two arguments'); + } + $left = $args[0]->evaluate($row, $conn); + $right = $args[1]->evaluate($row, $conn); + + return ($left === $right) ? 
null : $left; + } + + private function sqlFromUnixtime(row $row, AsyncMysqlConnection $conn): string { + $row = $this->maybeUnrollGroupedDataset($row); + + $args = $this->args; + if (C\count($args) !== 1) { + throw new SQLFakeRuntimeException('MySQL FROM_UNIXTIME() SQLFake only implemented for 1 argument'); + } + + $column = $args[0]->evaluate($row, $conn); + + // + // This is the default format from MySQL ‘YYYYY-MM-DD HH:MM:SS’ + // + + $format = 'Y-m-d G:i:s'; + return \date($format, (int)$column); + } + + private function sqlConcat(row $row, AsyncMysqlConnection $conn): string { + $row = $this->maybeUnrollGroupedDataset($row); + $args = $this->args; + if (C\count($args) < 2) { + throw new SQLFakeRuntimeException('MySQL CONCAT() function must be called with at least two arguments'); + } + $final_concat = ''; + foreach ($args as $arg) { + $val = (string)$arg->evaluate($row, $conn); + $final_concat .= $val; + } + return $final_concat; + } + + private function sqlConcatWS(row $row, AsyncMysqlConnection $conn): string { + $row = $this->maybeUnrollGroupedDataset($row); + $args = $this->args; + if (C\count($args) < 2) { + throw new SQLFakeRuntimeException('MySQL CONCAT_WS() function must be called with at least two arguments'); + } + $separator = $args[0]->evaluate($row, $conn); + if ($separator === null) { + throw new SQLFakeRuntimeException('MySQL CONCAT_WS() function required non null separator'); + } + $separator = (string)$separator; + $final_concat = ''; + foreach ($args as $k => $arg) { + if ($k < 1) { + continue; + } + $val = (string)$arg->evaluate($row, $conn); + if (Str\is_empty($final_concat)) { + $final_concat = $final_concat.$val; + } else { + $final_concat = $final_concat.$separator.$val; + } + } + return $final_concat; + } + + private function sqlField(row $row, AsyncMysqlConnection $conn): mixed { + $args = $this->args; + $num_args = C\count($args); + if ($num_args < 2) { + throw new SQLFakeRuntimeException('MySQL FIELD() function must be called with at 
least two arguments'); + } + + $value = $args[0]->evaluate($row, $conn); + foreach ($args as $k => $arg) { + if ($k < 1) { + continue; + } + if ($value == $arg->evaluate($row, $conn)) { + return $k; + } + } + return 0; + } + + private function sqlValues(row $row, AsyncMysqlConnection $conn): mixed { + $args = $this->args; + $num_args = C\count($args); + if ($num_args !== 1) { + throw new SQLFakeRuntimeException('MySQL VALUES() function must be called with one argument'); + } + + $arg = $args[0]; + if (!$arg is ColumnExpression) { + throw new SQLFakeRuntimeException('MySQL VALUES() function should be called with a column name'); + } + + // a bit hacky here, override so that the expression pulls the value from the sql_fake_values.* fields set in Query::applySet + $arg->prefixColumnExpression('sql_fake_values.'); + return $arg->evaluate($row, $conn); + } + + private function sqlReplace(row $row, AsyncMysqlConnection $conn): string { + $row = $this->maybeUnrollGroupedDataset($row); + $args = $this->args; + if (C\count($args) !== 3) { + throw new SQLFakeRuntimeException('MySQL REPLACE() function must be called with three arguments'); + } + return Str\replace_every( + (string)$args[0]->evaluate($row, $conn), + dict[(string)$args[1]->evaluate($row, $conn) => (string)$args[2]->evaluate($row, $conn)], + ); + } + + private function sqlUnixTimestamp(row $row, AsyncMysqlConnection $conn): int { $row = $this->maybeUnrollGroupedDataset($row); $args = $this->args; diff --git a/src/Expressions/InOperatorExpression.php b/src/Expressions/InOperatorExpression.php index 95a6bc3..57e24eb 100644 --- a/src/Expressions/InOperatorExpression.php +++ b/src/Expressions/InOperatorExpression.php @@ -9,103 +9,103 @@ */ final class InOperatorExpression extends Expression { - private ?vec $inList = null; + private ?vec $inList = null; - public function __construct(private Expression $left, public bool $negated = false) { - $op = Operator::IN; - $this->name = ''; - $this->precedence = 
ExpressionParser::OPERATOR_PRECEDENCE[operator_to_string($op)]; - $this->operator = $op; - $this->type = TokenType::OPERATOR; - } + public function __construct(private Expression $left, public bool $negated = false) { + $op = Operator::IN; + $this->name = ''; + $this->precedence = ExpressionParser::OPERATOR_PRECEDENCE[operator_to_string($op)]; + $this->operator = $op; + $this->type = TokenType::OPERATOR; + } - <<__Override>> - public function evaluateImpl(row $row, AsyncMysqlConnection $conn): bool { - $inList = $this->inList; - if ($inList === null || C\count($inList) === 0) { - throw new SQLFakeParseException('Parse error: empty IN list'); - } + <<__Override>> + public function evaluateImpl(row $row, AsyncMysqlConnection $conn): bool { + $inList = $this->inList; + if ($inList === null || C\count($inList) === 0) { + throw new SQLFakeParseException('Parse error: empty IN list'); + } - // - // Handle NULL as a special case: MySQL evaluates both "IN (NULL)" and "NOT IN (NULL)" to false, - // but while "IN (NULL)" might make sense when running a query with an empty IN list, - // "NOT IN (NULL)" almost certainly doesn't match what the developer is expecting. - // To avoid confusion, we just throw an SQLFakeException here. - // - if (C\count($inList) === 1 && $inList[0]->evaluate($row, $conn) === null) { - if (!$this->negated) { - return false; - } else { - throw new SQLFakeRuntimeException( - "You're probably trying to use NOT IN with an empty array, but MySQL would evaluate this to false.", - ); - } - } + // + // Handle NULL as a special case: MySQL evaluates both "IN (NULL)" and "NOT IN (NULL)" to false, + // but while "IN (NULL)" might make sense when running a query with an empty IN list, + // "NOT IN (NULL)" almost certainly doesn't match what the developer is expecting. + // To avoid confusion, we just throw an SQLFakeException here. 
+ // + if (C\count($inList) === 1 && $inList[0]->evaluate($row, $conn) === null) { + if (!$this->negated) { + return false; + } else { + throw new SQLFakeRuntimeException( + "You're probably trying to use NOT IN with an empty array, but MySQL would evaluate this to false.", + ); + } + } - $value = $this->left->evaluate($row, $conn); - foreach ($inList as $in_expr) { - // found it? return the opposite of "negated". so if negated is false, return true. - // if it's a subquery, we have to iterate over the results and extract the field from each row - if ($in_expr is SubqueryExpression) { - $ret = $in_expr->evaluate($row, $conn) as KeyedContainer<_, _>; - foreach ($ret as $r) { - $r as KeyedContainer<_, _>; - if (C\count($r) !== 1) { - throw new SQLFakeRuntimeException('Subquery result should contain 1 column'); - } - foreach ($r as $val) { - if ($value == $val) { - return !$this->negated; - } - } - } - } else { - if (\HH\Lib\Legacy_FIXME\eq($value, $in_expr->evaluate($row, $conn))) { - return !$this->negated; - } - } - } + $value = $this->left->evaluate($row, $conn); + foreach ($inList as $in_expr) { + // found it? return the opposite of "negated". so if negated is false, return true. 
+ // if it's a subquery, we have to iterate over the results and extract the field from each row + if ($in_expr is SubqueryExpression) { + $ret = $in_expr->evaluate($row, $conn) as KeyedContainer<_, _>; + foreach ($ret as $r) { + $r as KeyedContainer<_, _>; + if (C\count($r) !== 1) { + throw new SQLFakeRuntimeException('Subquery result should contain 1 column'); + } + foreach ($r as $val) { + if ($value == $val) { + return !$this->negated; + } + } + } + } else { + if (\HH\Lib\Legacy_FIXME\eq($value, $in_expr->evaluate($row, $conn))) { + return !$this->negated; + } + } + } - return $this->negated; - } + return $this->negated; + } - <<__Override>> - public function negate(): void { - $this->negated = true; - } + <<__Override>> + public function negate(): void { + $this->negated = true; + } - <<__Override>> - public function isWellFormed(): bool { - return $this->inList !== null; - } + <<__Override>> + public function isWellFormed(): bool { + return $this->inList !== null; + } - public function setInList(vec $list): void { - $this->inList = $list; - } + public function setInList(vec $list): void { + $this->inList = $list; + } - <<__Override>> - public function setNextChild(Expression $expr, bool $_overwrite = false): void { - $this->inList = vec[$expr]; - } + <<__Override>> + public function setNextChild(Expression $expr, bool $_overwrite = false): void { + $this->inList = vec[$expr]; + } - <<__Override>> - public function __debugInfo(): dict { - $inList = vec[]; - if ($this->inList !== null) { - foreach ($this->inList as $expr) { - $inList[] = \var_dump($expr, true); - } - } - $ret = dict[ - 'type' => 'IN', - 'left' => \var_dump($this->left, true), - 'in' => $inList, - 'negated' => $this->negated, - ]; + <<__Override>> + public function __debugInfo(): dict { + $inList = vec[]; + if ($this->inList !== null) { + foreach ($this->inList as $expr) { + $inList[] = \var_dump($expr, true); + } + } + $ret = dict[ + 'type' => 'IN', + 'left' => \var_dump($this->left, true), + 
'in' => $inList, + 'negated' => $this->negated, + ]; - if (!Str\is_empty($this->name)) { - $ret['name'] = $this->name; - } - return $ret; - } + if (!Str\is_empty($this->name)) { + $ret['name'] = $this->name; + } + return $ret; + } } diff --git a/src/Expressions/JSONFunctionExpression.hack b/src/Expressions/JSONFunctionExpression.hack index 77bb4da..fb47f14 100644 --- a/src/Expressions/JSONFunctionExpression.hack +++ b/src/Expressions/JSONFunctionExpression.hack @@ -8,445 +8,445 @@ use namespace Slack\SQLFake\JSONPath; * we implement as many as we want to in Hack */ final class JSONFunctionExpression extends BaseFunctionExpression { - const ExpressionEvaluationOpts RETAIN_JSON_EVAL_OPTS = shape( - 'encode_json' => false, - ); - - const ExpressionEvaluationOpts RETAIN_ALL_EVAL_OPTS = shape( - 'encode_json' => false, - 'bool_as_int' => false, - ); - - const JSONPath\GetOptions UNWRAP_JSON_PATH_RESULTS = shape('unwrap' => true); - - <<__Override>> - public function evaluateImpl(row $row, AsyncMysqlConnection $conn): mixed { - switch ($this->functionName) { - case 'JSON_VALID': - return $this->sqlJSONValid($row, $conn); - case 'JSON_QUOTE': - return $this->sqlJSONQuote($row, $conn); - case 'JSON_UNQUOTE': - return $this->sqlJSONUnquote($row, $conn); - case 'JSON_EXTRACT': - return $this->sqlJSONExtract($row, $conn); - case 'JSON_REPLACE': - return $this->sqlJSONReplace($row, $conn); - case 'JSON_KEYS': - return $this->sqlJSONKeys($row, $conn); - case 'JSON_LENGTH': - return $this->sqlJSONLength($row, $conn); - case 'JSON_DEPTH': - return $this->sqlJSONDepth($row, $conn); - case 'JSON_CONTAINS': - return $this->sqlJSONContains($row, $conn); - } - - throw new SQLFakeRuntimeException('Function '.$this->functionName.' 
not implemented yet'); - } - - private function sqlJSONValid(row $row, AsyncMysqlConnection $conn): ?bool { - $row = $this->maybeUnrollGroupedDataset($row); - $args = $this->args; - if (C\count($args) !== 1) { - throw new SQLFakeRuntimeException('MySQL JSON_VALID() function must be called with one argument'); - } - - $value = $args[0]->evaluate($row, $conn); - if ($value is null) { - return null; - } - - if (!($value is string)) { - return false; - } - - $value = Str\trim($value); - if ($value !== 'null' && \json_decode($value, true, 512, \JSON_FB_HACK_ARRAYS) is null) { - return false; - } - - return true; - } - - private function sqlJSONQuote(row $row, AsyncMysqlConnection $conn): ?string { - $row = $this->maybeUnrollGroupedDataset($row); - $args = $this->args; - if (C\count($args) !== 1) { - throw new SQLFakeRuntimeException('MySQL JSON_QUOTE() function must be called with one argument'); - } - - $value = $args[0]->evaluate($row, $conn, self::RETAIN_JSON_EVAL_OPTS); - if ($value is null) { - return null; - } - - if (!($value is string)) { - throw new SQLFakeRuntimeException('MySQL JSON_QUOTE() function received invalid argument'); - } - - return \json_encode($value, \JSON_UNESCAPED_UNICODE); - } - - private function sqlJSONUnquote(row $row, AsyncMysqlConnection $conn): ?string { - $row = $this->maybeUnrollGroupedDataset($row); - $args = $this->args; - if (C\count($args) !== 1) { - throw new SQLFakeRuntimeException('MySQL JSON_UNQUOTE() function must be called with one argument'); - } - - $value = $args[0]->evaluate($row, $conn); - if ($value is null) { - return null; - } - - if (!($value is string)) { - throw new SQLFakeRuntimeException('MySQL JSON_UNQUOTE() function received non string argument'); - } - - // If it begins & ends with ", it must be a valid JSON string literal so use json_decode to validate - // + decode - if ($value is string && Str\starts_with($value, '"') && Str\ends_with($value, '"')) { - $unquoted = \json_decode($value); - if ($unquoted is 
null || !($unquoted is string)) { - throw new SQLFakeRuntimeException('MySQL JSON_UNQUOTE() function received invalid argument'); - } - return (string)$unquoted; - } - - // MySQL doesn't seem to do anything at all if the string doesn't start & end with " - return $value; - } - - // Returns null, num or WrappedJSON - private function sqlJSONExtract(row $row, AsyncMysqlConnection $conn): ?WrappedJSON { - $row = $this->maybeUnrollGroupedDataset($row); - $args = $this->args; - if (C\count($args) < 2) { - throw new SQLFakeRuntimeException( - 'MySQL JSON_EXTRACT() function must be called with 1 JSON document & at least 1 JSON path', - ); - } - - $json = $args[0]->evaluate($row, $conn); - if ($json is null) { - return null; - } - if (!($json is string)) { - throw new SQLFakeRuntimeException('MySQL JSON_EXTRACT() function doc has incorrect type'); - } - - $jsonPathWithNulls = Vec\map( - Vec\slice($args, 1), - $a ==> { - $evaled = $a->evaluate($row, $conn, self::RETAIN_JSON_EVAL_OPTS); - if ($evaled is nonnull && !($evaled is string)) { - throw new SQLFakeRuntimeException( - 'MySQL JSON_EXTRACT() function encountered non string JSON path argument', - ); - } - return $evaled; - }, - ); - - $jsonPaths = Vec\filter_nulls($jsonPathWithNulls); - if (C\count($jsonPaths) !== C\count($jsonPathWithNulls)) { - return null; - } - - $results = vec[]; - try { - $jsonObject = new JSONPath\JSONObject($json); - - if (C\count($jsonPaths) === 1) { - // This is the only case, we can return the raw value instead of wrapping in vec[] - $result = $jsonObject->get($jsonPaths[0], self::UNWRAP_JSON_PATH_RESULTS); - if ($result is null) { - return null; - } - return new WrappedJSON($result->value); - } - - $results = C\reduce( - $jsonPaths, - ($acc, $path) ==> { - $result = $jsonObject->get($path, self::UNWRAP_JSON_PATH_RESULTS); - if ($result is null) { - return $acc; - } - $result = ($result->value is vec<_>) ? 
$result->value : vec[$result->value]; - return Vec\concat($acc, $result); - }, - vec[], - ); - } catch (JSONPath\JSONException $e) { - throw new SQLFakeRuntimeException('MySQL JSON_EXTRACT() function encountered error: '.$e->getMessage()); - } - - if (C\is_empty($results)) { - return null; - } - - return new WrappedJSON($results); - } - - private function sqlJSONReplace(row $row, AsyncMysqlConnection $conn): ?WrappedJSON { - $row = $this->maybeUnrollGroupedDataset($row); - $args = $this->args; - $arg_count = C\count($args); - if ($arg_count < 3 || $arg_count % 2 !== 1) { - throw new SQLFakeRuntimeException( - 'MySQL JSON_REPLACE() function must be called with 1 JSON document & at least 1 JSON path + replacement value pair ', - ); - } - - $json = $args[0]->evaluate($row, $conn); - if ($json is null) { - return null; - } - if (!($json is string)) { - throw new SQLFakeRuntimeException('MySQL JSON_EXTRACT() function doc has incorrect type'); - } - - $replacementsWithNulls = Vec\slice($args, 1) - |> Vec\map($$, $a ==> $a->evaluate($row, $conn, self::RETAIN_ALL_EVAL_OPTS)) - |> Vec\chunk($$, 2) - |> Vec\map($$, $v ==> { - $path = $v[0]; - if ($path is nonnull && !($path is string)) { - throw new SQLFakeRuntimeException('MySQL JSON_REPLACE() function encountered non string JSON path'); - } - - $value = $v[1]; - if ($value is WrappedJSON) { - $value = $value->rawValue(); - } - return shape('path' => $path, 'value' => $value); - }); - - $replacements = vec[]; - // Doing it this way to please typechecker that replacements doesn't have any NULL path - foreach ($replacementsWithNulls as $replacement) { - $path = $replacement['path']; - if ($path is nonnull) { - $replacements[] = shape('path' => $path, 'value' => $replacement['value']); - } - } - if (C\count($replacements) !== C\count($replacementsWithNulls)) { - return null; - } - - try { - $current = new JSONPath\JSONObject($json); - - foreach ($replacements as $replacement) { - $current = 
$current->replace($replacement['path'], $replacement['value'])->value; - } - - return new WrappedJSON($current->getValue()); - } catch (JSONPath\JSONException $e) { - throw new SQLFakeRuntimeException('MySQL JSON_REPLACE() function encountered error: '.$e->getMessage()); - } - } - - private function sqlJSONKeys(row $row, AsyncMysqlConnection $conn): ?WrappedJSON { - $row = $this->maybeUnrollGroupedDataset($row); - $args = $this->args; - - $argCount = C\count($args); - if ($argCount < 1) { - throw new SQLFakeRuntimeException( - 'MySQL JSON_KEYS() function must be called with at least 1 JSON document & optionally 1 JSON path', - ); - } - - $json = $args[0]->evaluate($row, $conn); - if ($json is null) { - return null; - } - if (!($json is string)) { - throw new SQLFakeRuntimeException('MySQL JSON_KEYS() function doc has incorrect type'); - } - - $path = $argCount > 1 ? $args[1]->evaluate($row, $conn, self::RETAIN_JSON_EVAL_OPTS) : '$'; - if ($path is null) { - return null; - } - if (!($path is string)) { - throw new SQLFakeRuntimeException('MySQL JSON_KEYS() function path has incorrect type'); - } - - try { - $keys = (new JSONPath\JSONObject($json))->keys($path); - if ($keys is null) { - return null; - } - return new WrappedJSON($keys->value); - } catch (JSONPath\JSONException $e) { - throw new SQLFakeRuntimeException('MySQL JSON_KEYS() function encountered error: '.$e->getMessage()); - } - } - - private function sqlJSONLength(row $row, AsyncMysqlConnection $conn): ?int { - $row = $this->maybeUnrollGroupedDataset($row); - $args = $this->args; - - $argCount = C\count($args); - if ($argCount < 1) { - throw new SQLFakeRuntimeException( - 'MySQL JSON_LENGTH() function must be called with at least 1 JSON document & optionally 1 JSON path', - ); - } - - $json = $args[0]->evaluate($row, $conn); - if ($json is null) { - return null; - } - if (!($json is string)) { - throw new SQLFakeRuntimeException('MySQL JSON_LENGTH() function doc has incorrect type'); - } - - $path = 
$argCount > 1 ? $args[1]->evaluate($row, $conn, self::RETAIN_JSON_EVAL_OPTS) : '$'; - if ($path is null) { - return null; - } - if (!($path is string)) { - throw new SQLFakeRuntimeException('MySQL JSON_LENGTH() function path has incorrect type'); - } - - try { - $keys = (new JSONPath\JSONObject($json))->length($path); - if ($keys is null) { - return null; - } - return $keys->value; - } catch (JSONPath\JSONException $e) { - throw new SQLFakeRuntimeException('MySQL JSON_LENGTH() function encountered error: '.$e->getMessage()); - } - } - - private function sqlJSONDepth(row $row, AsyncMysqlConnection $conn): ?int { - $row = $this->maybeUnrollGroupedDataset($row); - $args = $this->args; - - $argCount = C\count($args); - if ($argCount !== 1) { - throw new SQLFakeRuntimeException( - 'MySQL JSON_DEPTH() function must be called with 1 JSON document argument', - ); - } - - $json = $args[0]->evaluate($row, $conn); - if ($json is null) { - return null; - } - if (!($json is string)) { - throw new SQLFakeRuntimeException('MySQL JSON_DEPTH() function doc has incorrect type'); - } - - try { - return (new JSONPath\JSONObject($json))->depth()->value; - } catch (JSONPath\JSONException $e) { - throw new SQLFakeRuntimeException('MySQL JSON_DEPTH() function encountered error: '.$e->getMessage()); - } - } - - private function sqlJSONContains(row $row, AsyncMysqlConnection $conn): ?bool { - $row = $this->maybeUnrollGroupedDataset($row); - $args = $this->args; - $argCount = C\count($args); - - if ($argCount !== 2 && $argCount !== 3) { - throw new SQLFakeRuntimeException('MySQL JSON_CONTAINS() function must be called with 2 - 3 arguments'); - } - - // Get the json from the column - $json = $args[0]->evaluate($row, $conn); - if ($json is null) { - return null; - } - - if (!($json is string)) { - throw new SQLFakeRuntimeException('MySQL JSON_CONTAINS() function doc has incorrect type'); - } - - $path = '$'; // Set path to root initially in case no path is given - if ($argCount === 3) { - 
$path = $args[2]->evaluate($row, $conn, self::RETAIN_JSON_EVAL_OPTS); - - if (!($path is string)) { - throw new SQLFakeRuntimeException('MySQL JSON_CONTAINS() function path has incorrect type'); - } - } - - // Narrow down the json to the specified path - try { - $json = (new JSONPath\JSONObject($json))->get($path); - if ($json is null || $json->value is null || !($json->value is vec<_>)) { - throw new SQLFakeRuntimeException('MySQL JSON_CONTAINS() function given invalid json'); - } - $json = $json->value[0]; - } catch (JSONPath\JSONException $e) { - throw new SQLFakeRuntimeException('MySQL JSON_CONTAINS() function encountered error: '.$e->getMessage()); - } - - // Now check if the json contains the term - try { - $term = $args[1]->evaluate($row, $conn, self::RETAIN_JSON_EVAL_OPTS); - - if ($term is null) { - return null; - } - - if (!($term is string)) { - throw new SQLFakeRuntimeException('MySQL JSON_CONTAINS() function value has incorrect type'); - } - - $term = (new JSONPath\JSONObject($term))->get('$'); - if ($term is null || $term->value is null || !($term->value is vec<_>)) { - throw new SQLFakeRuntimeException('MySQL JSON_CONTAINS() function given invalid json'); - } - $term = $term->value[0]; - - if ($json is vec<_>) { - // If $json is a vec then we have an array and will test if the array contains the given value - if ($term is dict<_,_>) { - return C\count(Vec\filter($json, $val ==> { - if ($val is dict<_,_>) { - return Dict\equal($val, $term); - } - return false; - })) > 0; - } - else { - return C\contains($json, $term); - } - } - else if ($json is dict<_,_>) { - // If $json is a dict then we have an object and will test that either (1) $json and $term are the same or - // (2) one of $json's members is the same as $term - if ($term is dict<_,_>) { - if (Dict\equal($json, $term)) { return true; } - - return C\count(Dict\filter($json, $val ==> { - if ($val is dict<_,_>) { - return Dict\equal($val, $term); - } - return false; - })) > 0; - } - else { - 
return C\count(Dict\filter($json, $val ==> $term == $val)) > 0; - } - } - else { - return $json == $term; - } - - } catch (JSONPath\JSONException $e) { - throw new SQLFakeRuntimeException('MySQL JSON_CONTAINS() function encountered error: '.$e->getMessage()); - } - - return false; - } + const ExpressionEvaluationOpts RETAIN_JSON_EVAL_OPTS = shape( + 'encode_json' => false, + ); + + const ExpressionEvaluationOpts RETAIN_ALL_EVAL_OPTS = shape( + 'encode_json' => false, + 'bool_as_int' => false, + ); + + const JSONPath\GetOptions UNWRAP_JSON_PATH_RESULTS = shape('unwrap' => true); + + <<__Override>> + public function evaluateImpl(row $row, AsyncMysqlConnection $conn): mixed { + switch ($this->functionName) { + case 'JSON_VALID': + return $this->sqlJSONValid($row, $conn); + case 'JSON_QUOTE': + return $this->sqlJSONQuote($row, $conn); + case 'JSON_UNQUOTE': + return $this->sqlJSONUnquote($row, $conn); + case 'JSON_EXTRACT': + return $this->sqlJSONExtract($row, $conn); + case 'JSON_REPLACE': + return $this->sqlJSONReplace($row, $conn); + case 'JSON_KEYS': + return $this->sqlJSONKeys($row, $conn); + case 'JSON_LENGTH': + return $this->sqlJSONLength($row, $conn); + case 'JSON_DEPTH': + return $this->sqlJSONDepth($row, $conn); + case 'JSON_CONTAINS': + return $this->sqlJSONContains($row, $conn); + } + + throw new SQLFakeRuntimeException('Function '.$this->functionName.' 
not implemented yet'); + } + + private function sqlJSONValid(row $row, AsyncMysqlConnection $conn): ?bool { + $row = $this->maybeUnrollGroupedDataset($row); + $args = $this->args; + if (C\count($args) !== 1) { + throw new SQLFakeRuntimeException('MySQL JSON_VALID() function must be called with one argument'); + } + + $value = $args[0]->evaluate($row, $conn); + if ($value is null) { + return null; + } + + if (!($value is string)) { + return false; + } + + $value = Str\trim($value); + if ($value !== 'null' && \json_decode($value, true, 512, \JSON_FB_HACK_ARRAYS) is null) { + return false; + } + + return true; + } + + private function sqlJSONQuote(row $row, AsyncMysqlConnection $conn): ?string { + $row = $this->maybeUnrollGroupedDataset($row); + $args = $this->args; + if (C\count($args) !== 1) { + throw new SQLFakeRuntimeException('MySQL JSON_QUOTE() function must be called with one argument'); + } + + $value = $args[0]->evaluate($row, $conn, self::RETAIN_JSON_EVAL_OPTS); + if ($value is null) { + return null; + } + + if (!($value is string)) { + throw new SQLFakeRuntimeException('MySQL JSON_QUOTE() function received invalid argument'); + } + + return \json_encode($value, \JSON_UNESCAPED_UNICODE); + } + + private function sqlJSONUnquote(row $row, AsyncMysqlConnection $conn): ?string { + $row = $this->maybeUnrollGroupedDataset($row); + $args = $this->args; + if (C\count($args) !== 1) { + throw new SQLFakeRuntimeException('MySQL JSON_UNQUOTE() function must be called with one argument'); + } + + $value = $args[0]->evaluate($row, $conn); + if ($value is null) { + return null; + } + + if (!($value is string)) { + throw new SQLFakeRuntimeException('MySQL JSON_UNQUOTE() function received non string argument'); + } + + // If it begins & ends with ", it must be a valid JSON string literal so use json_decode to validate + // + decode + if ($value is string && Str\starts_with($value, '"') && Str\ends_with($value, '"')) { + $unquoted = \json_decode($value); + if ($unquoted is 
null || !($unquoted is string)) { + throw new SQLFakeRuntimeException('MySQL JSON_UNQUOTE() function received invalid argument'); + } + return (string)$unquoted; + } + + // MySQL doesn't seem to do anything at all if the string doesn't start & end with " + return $value; + } + + // Returns null, num or WrappedJSON + private function sqlJSONExtract(row $row, AsyncMysqlConnection $conn): ?WrappedJSON { + $row = $this->maybeUnrollGroupedDataset($row); + $args = $this->args; + if (C\count($args) < 2) { + throw new SQLFakeRuntimeException( + 'MySQL JSON_EXTRACT() function must be called with 1 JSON document & at least 1 JSON path', + ); + } + + $json = $args[0]->evaluate($row, $conn); + if ($json is null) { + return null; + } + if (!($json is string)) { + throw new SQLFakeRuntimeException('MySQL JSON_EXTRACT() function doc has incorrect type'); + } + + $jsonPathWithNulls = Vec\map( + Vec\slice($args, 1), + $a ==> { + $evaled = $a->evaluate($row, $conn, self::RETAIN_JSON_EVAL_OPTS); + if ($evaled is nonnull && !($evaled is string)) { + throw new SQLFakeRuntimeException( + 'MySQL JSON_EXTRACT() function encountered non string JSON path argument', + ); + } + return $evaled; + }, + ); + + $jsonPaths = Vec\filter_nulls($jsonPathWithNulls); + if (C\count($jsonPaths) !== C\count($jsonPathWithNulls)) { + return null; + } + + $results = vec[]; + try { + $jsonObject = new JSONPath\JSONObject($json); + + if (C\count($jsonPaths) === 1) { + // This is the only case, we can return the raw value instead of wrapping in vec[] + $result = $jsonObject->get($jsonPaths[0], self::UNWRAP_JSON_PATH_RESULTS); + if ($result is null) { + return null; + } + return new WrappedJSON($result->value); + } + + $results = C\reduce( + $jsonPaths, + ($acc, $path) ==> { + $result = $jsonObject->get($path, self::UNWRAP_JSON_PATH_RESULTS); + if ($result is null) { + return $acc; + } + $result = ($result->value is vec<_>) ? 
$result->value : vec[$result->value]; + return Vec\concat($acc, $result); + }, + vec[], + ); + } catch (JSONPath\JSONException $e) { + throw new SQLFakeRuntimeException('MySQL JSON_EXTRACT() function encountered error: '.$e->getMessage()); + } + + if (C\is_empty($results)) { + return null; + } + + return new WrappedJSON($results); + } + + private function sqlJSONReplace(row $row, AsyncMysqlConnection $conn): ?WrappedJSON { + $row = $this->maybeUnrollGroupedDataset($row); + $args = $this->args; + $arg_count = C\count($args); + if ($arg_count < 3 || $arg_count % 2 !== 1) { + throw new SQLFakeRuntimeException( + 'MySQL JSON_REPLACE() function must be called with 1 JSON document & at least 1 JSON path + replacement value pair ', + ); + } + + $json = $args[0]->evaluate($row, $conn); + if ($json is null) { + return null; + } + if (!($json is string)) { + throw new SQLFakeRuntimeException('MySQL JSON_EXTRACT() function doc has incorrect type'); + } + + $replacementsWithNulls = Vec\slice($args, 1) + |> Vec\map($$, $a ==> $a->evaluate($row, $conn, self::RETAIN_ALL_EVAL_OPTS)) + |> Vec\chunk($$, 2) + |> Vec\map($$, $v ==> { + $path = $v[0]; + if ($path is nonnull && !($path is string)) { + throw new SQLFakeRuntimeException('MySQL JSON_REPLACE() function encountered non string JSON path'); + } + + $value = $v[1]; + if ($value is WrappedJSON) { + $value = $value->rawValue(); + } + return shape('path' => $path, 'value' => $value); + }); + + $replacements = vec[]; + // Doing it this way to please typechecker that replacements doesn't have any NULL path + foreach ($replacementsWithNulls as $replacement) { + $path = $replacement['path']; + if ($path is nonnull) { + $replacements[] = shape('path' => $path, 'value' => $replacement['value']); + } + } + if (C\count($replacements) !== C\count($replacementsWithNulls)) { + return null; + } + + try { + $current = new JSONPath\JSONObject($json); + + foreach ($replacements as $replacement) { + $current = 
$current->replace($replacement['path'], $replacement['value'])->value; + } + + return new WrappedJSON($current->getValue()); + } catch (JSONPath\JSONException $e) { + throw new SQLFakeRuntimeException('MySQL JSON_REPLACE() function encountered error: '.$e->getMessage()); + } + } + + private function sqlJSONKeys(row $row, AsyncMysqlConnection $conn): ?WrappedJSON { + $row = $this->maybeUnrollGroupedDataset($row); + $args = $this->args; + + $argCount = C\count($args); + if ($argCount < 1) { + throw new SQLFakeRuntimeException( + 'MySQL JSON_KEYS() function must be called with at least 1 JSON document & optionally 1 JSON path', + ); + } + + $json = $args[0]->evaluate($row, $conn); + if ($json is null) { + return null; + } + if (!($json is string)) { + throw new SQLFakeRuntimeException('MySQL JSON_KEYS() function doc has incorrect type'); + } + + $path = $argCount > 1 ? $args[1]->evaluate($row, $conn, self::RETAIN_JSON_EVAL_OPTS) : '$'; + if ($path is null) { + return null; + } + if (!($path is string)) { + throw new SQLFakeRuntimeException('MySQL JSON_KEYS() function path has incorrect type'); + } + + try { + $keys = (new JSONPath\JSONObject($json))->keys($path); + if ($keys is null) { + return null; + } + return new WrappedJSON($keys->value); + } catch (JSONPath\JSONException $e) { + throw new SQLFakeRuntimeException('MySQL JSON_KEYS() function encountered error: '.$e->getMessage()); + } + } + + private function sqlJSONLength(row $row, AsyncMysqlConnection $conn): ?int { + $row = $this->maybeUnrollGroupedDataset($row); + $args = $this->args; + + $argCount = C\count($args); + if ($argCount < 1) { + throw new SQLFakeRuntimeException( + 'MySQL JSON_LENGTH() function must be called with at least 1 JSON document & optionally 1 JSON path', + ); + } + + $json = $args[0]->evaluate($row, $conn); + if ($json is null) { + return null; + } + if (!($json is string)) { + throw new SQLFakeRuntimeException('MySQL JSON_LENGTH() function doc has incorrect type'); + } + + $path = 
$argCount > 1 ? $args[1]->evaluate($row, $conn, self::RETAIN_JSON_EVAL_OPTS) : '$'; + if ($path is null) { + return null; + } + if (!($path is string)) { + throw new SQLFakeRuntimeException('MySQL JSON_LENGTH() function path has incorrect type'); + } + + try { + $keys = (new JSONPath\JSONObject($json))->length($path); + if ($keys is null) { + return null; + } + return $keys->value; + } catch (JSONPath\JSONException $e) { + throw new SQLFakeRuntimeException('MySQL JSON_LENGTH() function encountered error: '.$e->getMessage()); + } + } + + private function sqlJSONDepth(row $row, AsyncMysqlConnection $conn): ?int { + $row = $this->maybeUnrollGroupedDataset($row); + $args = $this->args; + + $argCount = C\count($args); + if ($argCount !== 1) { + throw new SQLFakeRuntimeException( + 'MySQL JSON_DEPTH() function must be called with 1 JSON document argument', + ); + } + + $json = $args[0]->evaluate($row, $conn); + if ($json is null) { + return null; + } + if (!($json is string)) { + throw new SQLFakeRuntimeException('MySQL JSON_DEPTH() function doc has incorrect type'); + } + + try { + return (new JSONPath\JSONObject($json))->depth()->value; + } catch (JSONPath\JSONException $e) { + throw new SQLFakeRuntimeException('MySQL JSON_DEPTH() function encountered error: '.$e->getMessage()); + } + } + + private function sqlJSONContains(row $row, AsyncMysqlConnection $conn): ?bool { + $row = $this->maybeUnrollGroupedDataset($row); + $args = $this->args; + $argCount = C\count($args); + + if ($argCount !== 2 && $argCount !== 3) { + throw new SQLFakeRuntimeException('MySQL JSON_CONTAINS() function must be called with 2 - 3 arguments'); + } + + // Get the json from the column + $json = $args[0]->evaluate($row, $conn); + if ($json is null) { + return null; + } + + if (!($json is string)) { + throw new SQLFakeRuntimeException('MySQL JSON_CONTAINS() function doc has incorrect type'); + } + + $path = '$'; // Set path to root initially in case no path is given + if ($argCount === 3) { + 
$path = $args[2]->evaluate($row, $conn, self::RETAIN_JSON_EVAL_OPTS); + + if (!($path is string)) { + throw new SQLFakeRuntimeException('MySQL JSON_CONTAINS() function path has incorrect type'); + } + } + + // Narrow down the json to the specified path + try { + $json = (new JSONPath\JSONObject($json))->get($path); + if ($json is null || $json->value is null || !($json->value is vec<_>)) { + throw new SQLFakeRuntimeException('MySQL JSON_CONTAINS() function given invalid json'); + } + $json = $json->value[0]; + } catch (JSONPath\JSONException $e) { + throw new SQLFakeRuntimeException('MySQL JSON_CONTAINS() function encountered error: '.$e->getMessage()); + } + + // Now check if the json contains the term + try { + $term = $args[1]->evaluate($row, $conn, self::RETAIN_JSON_EVAL_OPTS); + + if ($term is null) { + return null; + } + + if (!($term is string)) { + throw new SQLFakeRuntimeException('MySQL JSON_CONTAINS() function value has incorrect type'); + } + + $term = (new JSONPath\JSONObject($term))->get('$'); + if ($term is null || $term->value is null || !($term->value is vec<_>)) { + throw new SQLFakeRuntimeException('MySQL JSON_CONTAINS() function given invalid json'); + } + $term = $term->value[0]; + + if ($json is vec<_>) { + // If $json is a vec then we have an array and will test if the array contains the given value + if ($term is dict<_,_>) { + return C\count(Vec\filter($json, $val ==> { + if ($val is dict<_,_>) { + return Dict\equal($val, $term); + } + return false; + })) > 0; + } + else { + return C\contains($json, $term); + } + } + else if ($json is dict<_,_>) { + // If $json is a dict then we have an object and will test that either (1) $json and $term are the same or + // (2) one of $json's members is the same as $term + if ($term is dict<_,_>) { + if (Dict\equal($json, $term)) { return true; } + + return C\count(Dict\filter($json, $val ==> { + if ($val is dict<_,_>) { + return Dict\equal($val, $term); + } + return false; + })) > 0; + } + else { + 
return C\count(Dict\filter($json, $val ==> $term == $val)) > 0; + } + } + else { + return $json == $term; + } + + } catch (JSONPath\JSONException $e) { + throw new SQLFakeRuntimeException('MySQL JSON_CONTAINS() function encountered error: '.$e->getMessage()); + } + + return false; + } } diff --git a/src/Expressions/PlaceholderExpression.php b/src/Expressions/PlaceholderExpression.php index a255faa..b403bee 100644 --- a/src/Expressions/PlaceholderExpression.php +++ b/src/Expressions/PlaceholderExpression.php @@ -9,24 +9,24 @@ */ final class PlaceholderExpression extends Expression { - public function __construct() { - $this->precedence = 0; - $this->name = ''; - $this->type = TokenType::RESERVED; - } + public function __construct() { + $this->precedence = 0; + $this->name = ''; + $this->type = TokenType::RESERVED; + } - <<__Override>> - public function evaluateImpl(row $_row, AsyncMysqlConnection $_conn): mixed { - throw new SQLFakeRuntimeException('Attempted to evaluate placeholder expression!'); - } + <<__Override>> + public function evaluateImpl(row $_row, AsyncMysqlConnection $_conn): mixed { + throw new SQLFakeRuntimeException('Attempted to evaluate placeholder expression!'); + } - <<__Override>> - public function isWellFormed(): bool { - return false; - } + <<__Override>> + public function isWellFormed(): bool { + return false; + } - <<__Override>> - public function __debugInfo(): dict { - return dict['type' => 'placeholder']; - } + <<__Override>> + public function __debugInfo(): dict { + return dict['type' => 'placeholder']; + } } diff --git a/src/Expressions/RowExpression.php b/src/Expressions/RowExpression.php index 0fa5793..351b55d 100644 --- a/src/Expressions/RowExpression.php +++ b/src/Expressions/RowExpression.php @@ -8,34 +8,34 @@ */ final class RowExpression extends Expression { - public function __construct(private vec $elements) { - $this->precedence = 0; - $this->name = ''; - $this->type = TokenType::PAREN; - } - - <<__Override>> - public function 
evaluateImpl(row $row, AsyncMysqlConnection $conn): mixed { - - $result = vec[]; - - foreach ($this->elements as $expr) { - $result[] = $expr->evaluate($row, $conn); - } - return $result; - } - - <<__Override>> - public function isWellFormed(): bool { - return true; - } - - <<__Override>> - public function __debugInfo(): dict { - $elements = vec[]; - foreach ($this->elements as $elem) { - $elements[] = \var_dump($elem, true); - } - return dict['type' => 'row_expression', 'name' => $this->name, 'elements' => $elements]; - } + public function __construct(private vec $elements) { + $this->precedence = 0; + $this->name = ''; + $this->type = TokenType::PAREN; + } + + <<__Override>> + public function evaluateImpl(row $row, AsyncMysqlConnection $conn): mixed { + + $result = vec[]; + + foreach ($this->elements as $expr) { + $result[] = $expr->evaluate($row, $conn); + } + return $result; + } + + <<__Override>> + public function isWellFormed(): bool { + return true; + } + + <<__Override>> + public function __debugInfo(): dict { + $elements = vec[]; + foreach ($this->elements as $elem) { + $elements[] = \var_dump($elem, true); + } + return dict['type' => 'row_expression', 'name' => $this->name, 'elements' => $elements]; + } } diff --git a/src/Expressions/SubqueryExpression.php b/src/Expressions/SubqueryExpression.php index 065e307..e016edc 100644 --- a/src/Expressions/SubqueryExpression.php +++ b/src/Expressions/SubqueryExpression.php @@ -8,27 +8,27 @@ */ final class SubqueryExpression extends Expression { - public function __construct(private SelectQuery $query, public string $name) { - $this->precedence = 0; - $this->type = TokenType::CLAUSE; - } + public function __construct(private SelectQuery $query, public string $name) { + $this->precedence = 0; + $this->type = TokenType::CLAUSE; + } - <<__Override>> - /** - * Evaluate the subquery, passing the current row from the outer query along - * for correlated subqueries (not currently supported) - */ - public function 
evaluateImpl(row $row, AsyncMysqlConnection $conn): dataset { - return $this->query->execute($conn, $row); - } + <<__Override>> + /** + * Evaluate the subquery, passing the current row from the outer query along + * for correlated subqueries (not currently supported) + */ + public function evaluateImpl(row $row, AsyncMysqlConnection $conn): dataset { + return $this->query->execute($conn, $row); + } - <<__Override>> - public function isWellFormed(): bool { - return true; - } + <<__Override>> + public function isWellFormed(): bool { + return true; + } - <<__Override>> - public function __debugInfo(): dict { - return dict['type' => 'subquery', 'query' => $this->query, 'name' => $this->name]; - } + <<__Override>> + public function __debugInfo(): dict { + return dict['type' => 'subquery', 'query' => $this->query, 'name' => $this->name]; + } } diff --git a/src/Expressions/UnaryExpression.php b/src/Expressions/UnaryExpression.php index e57b373..9fb57e4 100644 --- a/src/Expressions/UnaryExpression.php +++ b/src/Expressions/UnaryExpression.php @@ -11,58 +11,58 @@ final class UnaryExpression extends Expression { - private ?Expression $subject = null; + private ?Expression $subject = null; - public function __construct(public ?Operator $operator) { - $this->type = TokenType::OPERATOR; - $this->precedence = 14; - $this->name = operatorn_to_string($operator); - } + public function __construct(public ?Operator $operator) { + $this->type = TokenType::OPERATOR; + $this->precedence = 14; + $this->name = operatorn_to_string($operator); + } - <<__Override>> - public function evaluateImpl(row $row, AsyncMysqlConnection $conn): mixed { - if ($this->subject === null) { - throw new SQLFakeRuntimeException('Attempted to evaluate unary operation with no operand'); - } - $val = $this->subject->evaluate($row, $conn); + <<__Override>> + public function evaluateImpl(row $row, AsyncMysqlConnection $conn): mixed { + if ($this->subject === null) { + throw new SQLFakeRuntimeException('Attempted to 
evaluate unary operation with no operand'); + } + $val = $this->subject->evaluate($row, $conn); - $op = $this->operator; - invariant($op is nonnull, 'This case was not considered. The operator is null.'); - switch ($op) { - case Operator::UNARY_MINUS: - return -1 * (float)$val; - case Operator::UNARY_PLUS: - return (float)$val; - case Operator::TILDE: - return ~(int)$val; - default: - throw new SQLFakeRuntimeException("Unimplemented unary operand {$this->name}"); - } + $op = $this->operator; + invariant($op is nonnull, 'This case was not considered. The operator is null.'); + switch ($op) { + case Operator::UNARY_MINUS: + return -1 * (float)$val; + case Operator::UNARY_PLUS: + return (float)$val; + case Operator::TILDE: + return ~(int)$val; + default: + throw new SQLFakeRuntimeException("Unimplemented unary operand {$this->name}"); + } - return $val; - } + return $val; + } - <<__Override>> - public function setNextChild(Expression $expr, bool $overwrite = false): void { - if ($this->subject is nonnull && !$overwrite) { - throw new SQLFakeParseException('Unexpected expression after unary operand'); - } - $this->subject = $expr; - } + <<__Override>> + public function setNextChild(Expression $expr, bool $overwrite = false): void { + if ($this->subject is nonnull && !$overwrite) { + throw new SQLFakeParseException('Unexpected expression after unary operand'); + } + $this->subject = $expr; + } - <<__Override>> - public function isWellFormed(): bool { - return $this->subject is nonnull; - } + <<__Override>> + public function isWellFormed(): bool { + return $this->subject is nonnull; + } - <<__Override>> - public function __debugInfo(): dict { - $subject = $this->subject ? \var_dump($this->subject, true) : dict[]; - return dict[ - 'type' => 'unary', - 'operator' => $this->operator, - 'name' => $this->name, - 'subject' => $subject, - ]; - } + <<__Override>> + public function __debugInfo(): dict { + $subject = $this->subject ? 
\var_dump($this->subject, true) : dict[]; + return dict[ + 'type' => 'unary', + 'operator' => $this->operator, + 'name' => $this->name, + 'subject' => $subject, + ]; + } } diff --git a/src/Formatting/Printf/Exception.php b/src/Formatting/Printf/Exception.php index 71c719b..e6b320b 100644 --- a/src/Formatting/Printf/Exception.php +++ b/src/Formatting/Printf/Exception.php @@ -7,35 +7,35 @@ abstract class PrintfException extends \RuntimeException {} final class InvalidFormatSpecifierException extends PrintfException { - public static function create( - string $format, - string $parsed, - int $char_idx, - Traversable $recognized_specifiers, - ): this { - return new static(Str\format( - 'Invalid format specifier, got %s at %d, supports (%s). Format: %s', - \var_export($parsed, true), - $char_idx, - Str\join(Vec\map($recognized_specifiers, $str ==> '%'.$str), ', '), - $format, - )); - } + public static function create( + string $format, + string $parsed, + int $char_idx, + Traversable $recognized_specifiers, + ): this { + return new static(Str\format( + 'Invalid format specifier, got %s at %d, supports (%s). Format: %s', + \var_export($parsed, true), + $char_idx, + Str\join(Vec\map($recognized_specifiers, $str ==> '%'.$str), ', '), + $format, + )); + } } final class TooFewArgumentsException extends PrintfException { - public static function create(string $format, int $arguments_provided): this { - return new static(Str\format('Too few arguments provided, got %d. Format: %s', $arguments_provided, $format)); - } + public static function create(string $format, int $arguments_provided): this { + return new static(Str\format('Too few arguments provided, got %d. Format: %s', $arguments_provided, $format)); + } } final class TooManyArgumentsException extends PrintfException { - public static function create(string $format, int $arguments_provided, int $arguments_consumed): this { - return new static(Str\format( - 'Too many arguments provided, got %d, expected %d. 
Format: %s', - $arguments_provided, - $arguments_consumed, - $format, - )); - } + public static function create(string $format, int $arguments_provided, int $arguments_consumed): this { + return new static(Str\format( + 'Too many arguments provided, got %d, expected %d. Format: %s', + $arguments_provided, + $arguments_consumed, + $format, + )); + } } diff --git a/src/Formatting/Printf/Formatter.php b/src/Formatting/Printf/Formatter.php index 6a55267..0b7d285 100644 --- a/src/Formatting/Printf/Formatter.php +++ b/src/Formatting/Printf/Formatter.php @@ -3,5 +3,5 @@ namespace Slack\SQLFake\Printf; interface Formatter { - public function format(string $format, vec $args): string; + public function format(string $format, vec $args): string; } diff --git a/src/Formatting/Printf/Printf.php b/src/Formatting/Printf/Printf.php index c4fbafc..e78c852 100644 --- a/src/Formatting/Printf/Printf.php +++ b/src/Formatting/Printf/Printf.php @@ -5,114 +5,114 @@ use namespace HH\Lib\{C, Math, Str, Vec}; final class Printf implements Formatter { - const type TFormatter = (function( - /*tuple*/(mixed /*$arg*/, int /*$modifier_index*/, string /*$format*/, string /*$modifier_text*/), - ): string); - private int $longestSpecifierLength; - public function __construct( - private dict bool, 'func' => this::TFormatter)> $specifiers, - ) { - $this->longestSpecifierLength = Vec\keys($specifiers) |> Vec\map($$, Str\length<>) |> Math\max($$) ?? 0; - } - - public function addSpecifier(string $key, shape('needs_arg' => bool, 'func' => this::TFormatter) $specifier): void { - $this->specifiers[$key] = $specifier; - $this->longestSpecifierLength = Math\maxva($this->longestSpecifierLength, Str\length($key)); - } - - public function format(string $format, vec $args): string { - // An optimization, still correct if removed. 
- if (!Str\contains($format, '%') && C\is_empty($args)) { - return $format; - } - - $modifiers = $this->identifyModifiers($format); - $args = $this->prepareArguments($format, $modifiers, $args); - return $this->formatImpl($format, $modifiers, $args); - } - - private function identifyModifiers(string $format): vec int, 'length' => int, 'text' => string)> { - $length = Str\length($format); - $max_length = $this->longestSpecifierLength; - $specifiers = $this->specifiers; - - $after_percent = false; - $buf = ''; - $out = vec[]; - - for ($i = 0; $i < $length; ++$i) { - $char = $format[$i]; - - if ($after_percent) { - $buf .= $char; - $buf_len = Str\length($buf); - if (C\contains_key($specifiers, $buf)) { - // The `char_idx` is the index of the `%` and the `+ 1` is required, since the `%` is not in `$buf`. - $out[] = shape('char_idx' => $i - $buf_len, 'length' => $buf_len + 1, 'text' => $buf); - $buf = ''; - $after_percent = false; - } else if ($buf_len >= $max_length) { - throw InvalidFormatSpecifierException::create($format, $buf, $i - $max_length, Vec\keys($specifiers)); - } - } else if ($char === '%') { - $after_percent = true; - } - } - - if ($after_percent) { - throw InvalidFormatSpecifierException::create($format, $buf, $i - $max_length, Vec\keys($specifiers)); - } - - return $out; - } - - private function prepareArguments( - string $format, - vec string, ...)> $modifiers, - vec $args, - ): vec { - $specifiers = $this->specifiers; - - $out = vec[]; - $arg_i = 0; - $arg_count = C\count($args); - - foreach ($modifiers as $mod_i => $m) { - $modifier_text = $m['text']; - $specifier = $specifiers[$modifier_text]; - if ($specifier['needs_arg']) { - if ($arg_i === $arg_count) { - throw TooFewArgumentsException::create($format, $arg_count); - } - $out[] = $specifier['func'](tuple($args[$arg_i], $mod_i, $format, $modifier_text)); - ++$arg_i; - } else { - $out[] = $specifier['func'](tuple(null, $mod_i, $format, $modifier_text)); - } - } - - if ($arg_i !== $arg_count) { - 
throw TooManyArgumentsException::create($format, $arg_count, $arg_i); - } - - return $out; - } - - private function formatImpl( - string $format, - vec int, 'length' => int, ...)> $modifiers, - vec $args, - ): string { - invariant(C\count($modifiers) === C\count($args), 'This invariant is upheld by prepareArguments'); - $out = ''; - $start_slice = 0; - - foreach ($modifiers as $i => $m) { - $out .= Str\slice($format, $start_slice, $m['char_idx'] - $start_slice); - $start_slice = $m['char_idx'] + $m['length']; - $out .= $args[$i]; - } - - return $out.Str\slice($format, $start_slice); - } + const type TFormatter = (function( + /*tuple*/(mixed /*$arg*/, int /*$modifier_index*/, string /*$format*/, string /*$modifier_text*/), + ): string); + private int $longestSpecifierLength; + public function __construct( + private dict bool, 'func' => this::TFormatter)> $specifiers, + ) { + $this->longestSpecifierLength = Vec\keys($specifiers) |> Vec\map($$, Str\length<>) |> Math\max($$) ?? 0; + } + + public function addSpecifier(string $key, shape('needs_arg' => bool, 'func' => this::TFormatter) $specifier): void { + $this->specifiers[$key] = $specifier; + $this->longestSpecifierLength = Math\maxva($this->longestSpecifierLength, Str\length($key)); + } + + public function format(string $format, vec $args): string { + // An optimization, still correct if removed. 
+ if (!Str\contains($format, '%') && C\is_empty($args)) { + return $format; + } + + $modifiers = $this->identifyModifiers($format); + $args = $this->prepareArguments($format, $modifiers, $args); + return $this->formatImpl($format, $modifiers, $args); + } + + private function identifyModifiers(string $format): vec int, 'length' => int, 'text' => string)> { + $length = Str\length($format); + $max_length = $this->longestSpecifierLength; + $specifiers = $this->specifiers; + + $after_percent = false; + $buf = ''; + $out = vec[]; + + for ($i = 0; $i < $length; ++$i) { + $char = $format[$i]; + + if ($after_percent) { + $buf .= $char; + $buf_len = Str\length($buf); + if (C\contains_key($specifiers, $buf)) { + // The `char_idx` is the index of the `%` and the `+ 1` is required, since the `%` is not in `$buf`. + $out[] = shape('char_idx' => $i - $buf_len, 'length' => $buf_len + 1, 'text' => $buf); + $buf = ''; + $after_percent = false; + } else if ($buf_len >= $max_length) { + throw InvalidFormatSpecifierException::create($format, $buf, $i - $max_length, Vec\keys($specifiers)); + } + } else if ($char === '%') { + $after_percent = true; + } + } + + if ($after_percent) { + throw InvalidFormatSpecifierException::create($format, $buf, $i - $max_length, Vec\keys($specifiers)); + } + + return $out; + } + + private function prepareArguments( + string $format, + vec string, ...)> $modifiers, + vec $args, + ): vec { + $specifiers = $this->specifiers; + + $out = vec[]; + $arg_i = 0; + $arg_count = C\count($args); + + foreach ($modifiers as $mod_i => $m) { + $modifier_text = $m['text']; + $specifier = $specifiers[$modifier_text]; + if ($specifier['needs_arg']) { + if ($arg_i === $arg_count) { + throw TooFewArgumentsException::create($format, $arg_count); + } + $out[] = $specifier['func'](tuple($args[$arg_i], $mod_i, $format, $modifier_text)); + ++$arg_i; + } else { + $out[] = $specifier['func'](tuple(null, $mod_i, $format, $modifier_text)); + } + } + + if ($arg_i !== $arg_count) { + 
throw TooManyArgumentsException::create($format, $arg_count, $arg_i); + } + + return $out; + } + + private function formatImpl( + string $format, + vec int, 'length' => int, ...)> $modifiers, + vec $args, + ): string { + invariant(C\count($modifiers) === C\count($args), 'This invariant is upheld by prepareArguments'); + $out = ''; + $start_slice = 0; + + foreach ($modifiers as $i => $m) { + $out .= Str\slice($format, $start_slice, $m['char_idx'] - $start_slice); + $start_slice = $m['char_idx'] + $m['length']; + $out .= $args[$i]; + } + + return $out.Str\slice($format, $start_slice); + } } diff --git a/src/Formatting/QueryStringifier.php b/src/Formatting/QueryStringifier.php index 5b62d88..736b540 100644 --- a/src/Formatting/QueryStringifier.php +++ b/src/Formatting/QueryStringifier.php @@ -7,218 +7,218 @@ use type HH\Lib\SQL\Query; final class QueryStringifier { - public function __construct( - private Printf\Formatter $queryAsyncFormatter, - private Printf\Formatter $queryfFormatter, - ) {} - - public function formatString(string $format, vec $args): string { - static::assertNoDangerousCharacters($format); - try { - return $this->queryfFormatter->format($format, $args); - } catch (Printf\PrintfException $e) { - throw new SQLFakeParseException('See previous exception', $e->getCode(), $e); - } - } - - public function formatQuery(Query $query): string { - list($format, $args) = static::reflectQueryContents($query); - static::assertNoDangerousCharacters($format); - try { - return $this->queryAsyncFormatter->format($format, $args); - } catch (Printf\PrintfException $e) { - throw new SQLFakeParseException('See previous exception', $e->getCode(), $e); - } - } - - public static function createForTypesafeHack(): this { - return new static(static::typesafeQueryAsyncFormatter(), static::typesafeQueryfFormatter()); - } - - public static function typesafeQueryAsyncFormatter(): Printf\Formatter { - $with_arg = (Printf\Printf::TFormatter $func) ==> shape('needs_arg' => true, 
'func' => $func); - $no_arg = (Printf\Printf::TFormatter $func) ==> shape( - 'needs_arg' => false, - 'func' => ((mixed, int, string, string) $arg) ==> { - invariant($arg[0] is null, 'We should never consume a real argument here'); - return $func($arg); - }, - ); - - $query_format = new _Private\QueryFormat(); - $equals_format = new _Private\EqualsScalarFormat(); - $list_format = new _Private\ListFormat(); - - $query_async = new Printf\Printf(dict[ - '%' => $no_arg($_meh ==> $query_format->format_0x25()), - 'd' => $with_arg($arg ==> $query_format->format_d(static::guardType('?int', ...$arg))), - 'f' => $with_arg($arg ==> $query_format->format_f(static::guardType('?float', ...$arg))), - 's' => $with_arg($arg ==> $query_format->format_s(static::guardType('?string', ...$arg))), - 'C' => $with_arg($arg ==> $query_format->format_upcase_c(static::guardType('string', ...$arg))), - 'K' => $with_arg($arg ==> $query_format->format_upcase_k(static::guardType('string', ...$arg))), - 'Ld' => $with_arg($arg ==> $list_format->format_d(static::guardVecOfType('int', ...$arg))), - 'Lf' => $with_arg($arg ==> $list_format->format_f(static::guardVecOfType('float', ...$arg))), - 'Ls' => $with_arg($arg ==> $list_format->format_s(static::guardVecOfType('string', ...$arg))), - 'LC' => $with_arg($arg ==> $list_format->format_upcase_c(static::guardVecOfType('string', ...$arg))), - 'T' => $with_arg($arg ==> $query_format->format_upcase_t(static::guardType('string', ...$arg))), - '=d' => $with_arg($arg ==> $equals_format->format_d(static::guardType('?int', ...$arg))), - '=f' => $with_arg($arg ==> $equals_format->format_f(static::guardType('?float', ...$arg))), - '=s' => $with_arg($arg ==> $equals_format->format_s(static::guardType('?string', ...$arg))), - ]); - - $self = new static($query_async, new Printf\Printf(dict[])); - $query_async->addSpecifier( - 'Q', - $with_arg($arg ==> $self->formatQuery(static::guardType(Query::class, ...$arg))), - ); - return $query_async; - } - - public static 
function typesafeQueryfFormatter(): Printf\Formatter { - $with_arg = (Printf\Printf::TFormatter $func) ==> shape('needs_arg' => true, 'func' => $func); - $no_arg = (Printf\Printf::TFormatter $func) ==> shape( - 'needs_arg' => false, - 'func' => ((mixed, int, string, string) $arg) ==> { - invariant($arg[0] is null, 'We should never consume a real argument here'); - return $func($arg); - }, - ); - - $sql_format = new _Private\SQLFormatter(new _Private\QueryFormat()); - $sql_list_format = new _Private\SQLListFormatter(new _Private\ListFormat()); - $sql_equals_formatter = new _Private\SQLEqualsScalarFormatter(new _Private\EqualsScalarFormat()); - - $queryf = new Printf\Printf(dict[ - '%' => $no_arg($_meh ==> $sql_format->format_0x25()), - 'd' => $with_arg($arg ==> $sql_format->format_d(static::guardType('?int', ...$arg))), - 'f' => $with_arg($arg ==> $sql_format->format_f(static::guardType('?float', ...$arg))), - 's' => $with_arg($arg ==> $sql_format->format_s(static::guardType('?string', ...$arg))), - 'C' => $with_arg($arg ==> $sql_format->format_upcase_c(static::guardType('string', ...$arg))), - 'Ld' => $with_arg($arg ==> $sql_list_format->format_d(static::guardVectorOfType('int', ...$arg))), - 'Lf' => $with_arg($arg ==> $sql_list_format->format_f(static::guardVectorOfType('float', ...$arg))), - 'Ls' => $with_arg($arg ==> $sql_list_format->format_s(static::guardVectorOfType('string', ...$arg))), - 'LC' => - $with_arg($arg ==> $sql_list_format->format_upcase_c(static::guardVectorOfType('string', ...$arg))), - 'T' => $with_arg($arg ==> $sql_format->format_upcase_t(static::guardType('string', ...$arg))), - '=d' => $with_arg($arg ==> $sql_equals_formatter->format_d(static::guardType('?int', ...$arg))), - '=f' => $with_arg($arg ==> $sql_equals_formatter->format_f(static::guardType('?float', ...$arg))), - '=s' => $with_arg($arg ==> $sql_equals_formatter->format_s(static::guardType('?string', ...$arg))), - ]); - - return $queryf; - } - - private static function 
assertNoDangerousCharacters(string $format): void { - // Reimplementing this check from squangle - // https://github.com/facebook/squangle/blob/16a37e240583cbd1278d1dfe359c7d53229ea3e0/squangle/mysql_client/Query.cpp#L474 - foreach (vec[';', "'", '"', '`'] as $dangerous_char) { - $idx = Str\search($format, $dangerous_char); - if ($idx) { - throw new SQLFakeParseException( - Str\format('Saw dangerous character %s in SQL query as index %d. Query: %s', $dangerous_char, $idx, $format), - ); - } - } - } - - private static function getType(mixed $var): string { - return \is_object($var) ? \get_class($var) : \gettype($var); - } - - private static function guardType<<<__Enforceable>> reify T>( - string $type_text, - mixed $arg, - int $modifier_index, - string $format, - string $modifier_text, - ): T { - if ($arg is T) { - return $arg; - } - throw new SQLFakeParseException(Str\format( - 'Expected %s for specifier %%%s at index %d, got %s. Query: %s', - $type_text, - $modifier_text, - $modifier_index, - static::getType($arg), - $format, - )); - } - - private static function guardVecOfType<<<__Enforceable>> reify T>( - string $type_text, - mixed $arg, - int $modifier_index, - string $format, - string $modifier_text, - ): vec { - if ($arg is vec<_>) { - $out = vec[]; - foreach ($arg as $sub_arg) { - if (!$sub_arg is T) { - throw new SQLFakeParseException(Str\format( - 'Expected all elements of vec to be %s for specifier %%%s at index %d, got %s. Query: %s', - $type_text, - $modifier_text, - $modifier_index, - static::getType($sub_arg), - $format, - )); - } - $out[] = $sub_arg; - } - return $out; - } - throw new SQLFakeParseException(Str\format( - 'Expected vec<%s> for specifier %%%s at index %d, got %s. 
Query: %s', - $type_text, - $modifier_text, - $modifier_index, - static::getType($arg), - $format, - )); - } - - private static function guardVectorOfType<<<__Enforceable>> reify T>( - string $type_text, - mixed $arg, - int $modifier_index, - string $format, - string $modifier_text, - ): ImmVector { - if ($arg is ConstVector<_>) { - $out = Vector {}; - foreach ($arg as $sub_arg) { - if (!$sub_arg is T) { - throw new SQLFakeParseException(Str\format( - 'Expected all elements of ConstVector to be %s for specifier %%%s at index %d, got %s. Query: %s', - $type_text, - $modifier_text, - $modifier_index, - static::getType($sub_arg), - $format, - )); - } - $out[] = $sub_arg; - } - return $out->immutable(); - } - throw new SQLFakeParseException(Str\format( - 'Expected ConstVector<%s> for specifier %%%s at index %d, got %s. Query: %s', - $type_text, - $modifier_text, - $modifier_index, - static::getType($arg), - $format, - )); - } - - private static function reflectQueryContents(Query $query): (string, vec) { - $ro = new \ReflectionObject($query); - $format = $ro->getProperty('format'); - $format->setAccessible(true); - $args = $ro->getProperty('args'); - $args->setAccessible(true); - return tuple($format->getValue($query), vec($args->getValue($query))); - } + public function __construct( + private Printf\Formatter $queryAsyncFormatter, + private Printf\Formatter $queryfFormatter, + ) {} + + public function formatString(string $format, vec $args): string { + static::assertNoDangerousCharacters($format); + try { + return $this->queryfFormatter->format($format, $args); + } catch (Printf\PrintfException $e) { + throw new SQLFakeParseException('See previous exception', $e->getCode(), $e); + } + } + + public function formatQuery(Query $query): string { + list($format, $args) = static::reflectQueryContents($query); + static::assertNoDangerousCharacters($format); + try { + return $this->queryAsyncFormatter->format($format, $args); + } catch (Printf\PrintfException $e) { + throw 
new SQLFakeParseException('See previous exception', $e->getCode(), $e); + } + } + + public static function createForTypesafeHack(): this { + return new static(static::typesafeQueryAsyncFormatter(), static::typesafeQueryfFormatter()); + } + + public static function typesafeQueryAsyncFormatter(): Printf\Formatter { + $with_arg = (Printf\Printf::TFormatter $func) ==> shape('needs_arg' => true, 'func' => $func); + $no_arg = (Printf\Printf::TFormatter $func) ==> shape( + 'needs_arg' => false, + 'func' => ((mixed, int, string, string) $arg) ==> { + invariant($arg[0] is null, 'We should never consume a real argument here'); + return $func($arg); + }, + ); + + $query_format = new _Private\QueryFormat(); + $equals_format = new _Private\EqualsScalarFormat(); + $list_format = new _Private\ListFormat(); + + $query_async = new Printf\Printf(dict[ + '%' => $no_arg($_meh ==> $query_format->format_0x25()), + 'd' => $with_arg($arg ==> $query_format->format_d(static::guardType('?int', ...$arg))), + 'f' => $with_arg($arg ==> $query_format->format_f(static::guardType('?float', ...$arg))), + 's' => $with_arg($arg ==> $query_format->format_s(static::guardType('?string', ...$arg))), + 'C' => $with_arg($arg ==> $query_format->format_upcase_c(static::guardType('string', ...$arg))), + 'K' => $with_arg($arg ==> $query_format->format_upcase_k(static::guardType('string', ...$arg))), + 'Ld' => $with_arg($arg ==> $list_format->format_d(static::guardVecOfType('int', ...$arg))), + 'Lf' => $with_arg($arg ==> $list_format->format_f(static::guardVecOfType('float', ...$arg))), + 'Ls' => $with_arg($arg ==> $list_format->format_s(static::guardVecOfType('string', ...$arg))), + 'LC' => $with_arg($arg ==> $list_format->format_upcase_c(static::guardVecOfType('string', ...$arg))), + 'T' => $with_arg($arg ==> $query_format->format_upcase_t(static::guardType('string', ...$arg))), + '=d' => $with_arg($arg ==> $equals_format->format_d(static::guardType('?int', ...$arg))), + '=f' => $with_arg($arg ==> 
$equals_format->format_f(static::guardType('?float', ...$arg))), + '=s' => $with_arg($arg ==> $equals_format->format_s(static::guardType('?string', ...$arg))), + ]); + + $self = new static($query_async, new Printf\Printf(dict[])); + $query_async->addSpecifier( + 'Q', + $with_arg($arg ==> $self->formatQuery(static::guardType(Query::class, ...$arg))), + ); + return $query_async; + } + + public static function typesafeQueryfFormatter(): Printf\Formatter { + $with_arg = (Printf\Printf::TFormatter $func) ==> shape('needs_arg' => true, 'func' => $func); + $no_arg = (Printf\Printf::TFormatter $func) ==> shape( + 'needs_arg' => false, + 'func' => ((mixed, int, string, string) $arg) ==> { + invariant($arg[0] is null, 'We should never consume a real argument here'); + return $func($arg); + }, + ); + + $sql_format = new _Private\SQLFormatter(new _Private\QueryFormat()); + $sql_list_format = new _Private\SQLListFormatter(new _Private\ListFormat()); + $sql_equals_formatter = new _Private\SQLEqualsScalarFormatter(new _Private\EqualsScalarFormat()); + + $queryf = new Printf\Printf(dict[ + '%' => $no_arg($_meh ==> $sql_format->format_0x25()), + 'd' => $with_arg($arg ==> $sql_format->format_d(static::guardType('?int', ...$arg))), + 'f' => $with_arg($arg ==> $sql_format->format_f(static::guardType('?float', ...$arg))), + 's' => $with_arg($arg ==> $sql_format->format_s(static::guardType('?string', ...$arg))), + 'C' => $with_arg($arg ==> $sql_format->format_upcase_c(static::guardType('string', ...$arg))), + 'Ld' => $with_arg($arg ==> $sql_list_format->format_d(static::guardVectorOfType('int', ...$arg))), + 'Lf' => $with_arg($arg ==> $sql_list_format->format_f(static::guardVectorOfType('float', ...$arg))), + 'Ls' => $with_arg($arg ==> $sql_list_format->format_s(static::guardVectorOfType('string', ...$arg))), + 'LC' => + $with_arg($arg ==> $sql_list_format->format_upcase_c(static::guardVectorOfType('string', ...$arg))), + 'T' => $with_arg($arg ==> 
$sql_format->format_upcase_t(static::guardType('string', ...$arg))), + '=d' => $with_arg($arg ==> $sql_equals_formatter->format_d(static::guardType('?int', ...$arg))), + '=f' => $with_arg($arg ==> $sql_equals_formatter->format_f(static::guardType('?float', ...$arg))), + '=s' => $with_arg($arg ==> $sql_equals_formatter->format_s(static::guardType('?string', ...$arg))), + ]); + + return $queryf; + } + + private static function assertNoDangerousCharacters(string $format): void { + // Reimplementing this check from squangle + // https://github.com/facebook/squangle/blob/16a37e240583cbd1278d1dfe359c7d53229ea3e0/squangle/mysql_client/Query.cpp#L474 + foreach (vec[';', "'", '"', '`'] as $dangerous_char) { + $idx = Str\search($format, $dangerous_char); + if ($idx) { + throw new SQLFakeParseException( + Str\format('Saw dangerous character %s in SQL query as index %d. Query: %s', $dangerous_char, $idx, $format), + ); + } + } + } + + private static function getType(mixed $var): string { + return \is_object($var) ? \get_class($var) : \gettype($var); + } + + private static function guardType<<<__Enforceable>> reify T>( + string $type_text, + mixed $arg, + int $modifier_index, + string $format, + string $modifier_text, + ): T { + if ($arg is T) { + return $arg; + } + throw new SQLFakeParseException(Str\format( + 'Expected %s for specifier %%%s at index %d, got %s. Query: %s', + $type_text, + $modifier_text, + $modifier_index, + static::getType($arg), + $format, + )); + } + + private static function guardVecOfType<<<__Enforceable>> reify T>( + string $type_text, + mixed $arg, + int $modifier_index, + string $format, + string $modifier_text, + ): vec { + if ($arg is vec<_>) { + $out = vec[]; + foreach ($arg as $sub_arg) { + if (!$sub_arg is T) { + throw new SQLFakeParseException(Str\format( + 'Expected all elements of vec to be %s for specifier %%%s at index %d, got %s. 
Query: %s', + $type_text, + $modifier_text, + $modifier_index, + static::getType($sub_arg), + $format, + )); + } + $out[] = $sub_arg; + } + return $out; + } + throw new SQLFakeParseException(Str\format( + 'Expected vec<%s> for specifier %%%s at index %d, got %s. Query: %s', + $type_text, + $modifier_text, + $modifier_index, + static::getType($arg), + $format, + )); + } + + private static function guardVectorOfType<<<__Enforceable>> reify T>( + string $type_text, + mixed $arg, + int $modifier_index, + string $format, + string $modifier_text, + ): ImmVector { + if ($arg is ConstVector<_>) { + $out = Vector {}; + foreach ($arg as $sub_arg) { + if (!$sub_arg is T) { + throw new SQLFakeParseException(Str\format( + 'Expected all elements of ConstVector to be %s for specifier %%%s at index %d, got %s. Query: %s', + $type_text, + $modifier_text, + $modifier_index, + static::getType($sub_arg), + $format, + )); + } + $out[] = $sub_arg; + } + return $out->immutable(); + } + throw new SQLFakeParseException(Str\format( + 'Expected ConstVector<%s> for specifier %%%s at index %d, got %s. 
Query: %s', + $type_text, + $modifier_text, + $modifier_index, + static::getType($arg), + $format, + )); + } + + private static function reflectQueryContents(Query $query): (string, vec) { + $ro = new \ReflectionObject($query); + $format = $ro->getProperty('format'); + $format->setAccessible(true); + $args = $ro->getProperty('args'); + $args->setAccessible(true); + return tuple($format->getValue($query), vec($args->getValue($query))); + } } diff --git a/src/Formatting/implementations_of_query_format.php b/src/Formatting/implementations_of_query_format.php index 95e2101..0beb4d1 100644 --- a/src/Formatting/implementations_of_query_format.php +++ b/src/Formatting/implementations_of_query_format.php @@ -7,19 +7,19 @@ use namespace HH\Lib\{SQL, Str, Vec}; function comment_string(string $comment): string { - return '/*'.Str\replace_every($comment, dict['/*' => '/ * ', '*/' => '* / ']).'*/'; + return '/*'.Str\replace_every($comment, dict['/*' => '/ * ', '*/' => '* / ']).'*/'; } function escape_string(string $string): string { - return '"'.\mysql_escape_string($string).'"'; + return '"'.\mysql_escape_string($string).'"'; } function float_string(float $float): string { - return Str\trim_right(Str\format_number($float, 14, '.', ''), '0'); + return Str\trim_right(Str\format_number($float, 14, '.', ''), '0'); } function identifier_string(string $identifier): string { - return '`'.Str\replace($identifier, '`', '``').'`'; + return '`'.Str\replace($identifier, '`', '``').'`'; } // It is a really nice idea to "implement" the hhi from HH\Lib\SQL. @@ -30,64 +30,64 @@ function identifier_string(string $identifier): string { // fatal loudly. I'll file an issue against hhvm later today. final class EqualsScalarFormat /* implements SQL\ScalarFormat */ { - public function format_d(?int $int): string { - return $int is null ? ' IS NULL' : ' = '.$int; - } - public function format_f(?float $float): string { - return $float is null ? 
' IS NULL' : ' = '.float_string($float); - } - public function format_s(?string $string): string { - return $string is null ? ' IS NULL' : ' = '.escape_string($string); - } + public function format_d(?int $int): string { + return $int is null ? ' IS NULL' : ' = '.$int; + } + public function format_f(?float $float): string { + return $float is null ? ' IS NULL' : ' = '.float_string($float); + } + public function format_s(?string $string): string { + return $string is null ? ' IS NULL' : ' = '.escape_string($string); + } } final class ListFormat /* implements SQL\ListFormat */ { - public function format_upcase_c(vec $columns): string { - return Vec\map($columns, identifier_string<>) |> Str\join($$, ', '); - } - public function format_d(vec $ints): string { - return Str\join($ints, ', '); - } - public function format_f(vec $floats): string { - return Vec\map($floats, float_string<>) |> Str\join($$, ', '); - } - public function format_s(vec $strings): string { - return Vec\map($strings, escape_string<>) |> Str\join($$, ', '); - } + public function format_upcase_c(vec $columns): string { + return Vec\map($columns, identifier_string<>) |> Str\join($$, ', '); + } + public function format_d(vec $ints): string { + return Str\join($ints, ', '); + } + public function format_f(vec $floats): string { + return Vec\map($floats, float_string<>) |> Str\join($$, ', '); + } + public function format_s(vec $strings): string { + return Vec\map($strings, escape_string<>) |> Str\join($$, ', '); + } } final class QueryFormat /* implements SQL\QueryFormat */ { - public function format_0x25(): string { - return '%'; - } - public function format_d(?int $int): string { - return $int is null ? 'NULL' : (string)$int; - } - public function format_f(?float $float): string { - return $float is null ? 'NULL' : float_string($float); - } - public function format_s(?string $string): string { - return $string is null ? 
'NULL' : escape_string($string); - } - public function format_upcase_c(string $column): string { - return identifier_string($column); - } - public function format_upcase_k(string $comment): string { - return comment_string($comment); - } - public function format_upcase_l(): ListFormat { - return new ListFormat(); - } - public function format_upcase_q(SQL\Query $_query): string { - invariant_violation( - 'There are multiple valid implementation for %s, so you must implement it externally.', - __METHOD__, - ); - } - public function format_upcase_t(string $table): string { - return identifier_string($table); - } - public function format_0x3d(): EqualsScalarFormat { - return new EqualsScalarFormat(); - } + public function format_0x25(): string { + return '%'; + } + public function format_d(?int $int): string { + return $int is null ? 'NULL' : (string)$int; + } + public function format_f(?float $float): string { + return $float is null ? 'NULL' : float_string($float); + } + public function format_s(?string $string): string { + return $string is null ? 
'NULL' : escape_string($string); + } + public function format_upcase_c(string $column): string { + return identifier_string($column); + } + public function format_upcase_k(string $comment): string { + return comment_string($comment); + } + public function format_upcase_l(): ListFormat { + return new ListFormat(); + } + public function format_upcase_q(SQL\Query $_query): string { + invariant_violation( + 'There are multiple valid implementation for %s, so you must implement it externally.', + __METHOD__, + ); + } + public function format_upcase_t(string $table): string { + return identifier_string($table); + } + public function format_0x3d(): EqualsScalarFormat { + return new EqualsScalarFormat(); + } } diff --git a/src/Formatting/implementations_of_queryf.php b/src/Formatting/implementations_of_queryf.php index 1377e4c..f9dbc8b 100644 --- a/src/Formatting/implementations_of_queryf.php +++ b/src/Formatting/implementations_of_queryf.php @@ -14,60 +14,60 @@ // fatal loudly. I'll file an issue against hhvm later today. 
final class SQLEqualsScalarFormatter /* implements \HH\SQLScalarFormatter */ { - public function __construct(private EqualsScalarFormat $scalarFormat) {} + public function __construct(private EqualsScalarFormat $scalarFormat) {} - public function format_d(?int $int): string { - return $this->scalarFormat->format_d($int); - } - public function format_f(?float $float): string { - return $this->scalarFormat->format_f($float); - } - public function format_s(?string $string): string { - return $this->scalarFormat->format_s($string); - } + public function format_d(?int $int): string { + return $this->scalarFormat->format_d($int); + } + public function format_f(?float $float): string { + return $this->scalarFormat->format_f($float); + } + public function format_s(?string $string): string { + return $this->scalarFormat->format_s($string); + } } final class SQLListFormatter /* implements \HH\SQLListFormatter */ { - public function __construct(private ListFormat $listFormat) {} + public function __construct(private ListFormat $listFormat) {} - public function format_upcase_c(ConstVector $columns): string { - return $this->listFormat->format_upcase_c(vec($columns)); - } - public function format_d(ConstVector $ints): string { - return $this->listFormat->format_d(vec($ints)); - } - public function format_f(ConstVector $floats): string { - return $this->listFormat->format_f(vec($floats)); - } - public function format_s(ConstVector $strings): string { - return $this->listFormat->format_s(vec($strings)); - } + public function format_upcase_c(ConstVector $columns): string { + return $this->listFormat->format_upcase_c(vec($columns)); + } + public function format_d(ConstVector $ints): string { + return $this->listFormat->format_d(vec($ints)); + } + public function format_f(ConstVector $floats): string { + return $this->listFormat->format_f(vec($floats)); + } + public function format_s(ConstVector $strings): string { + return $this->listFormat->format_s(vec($strings)); + } } final 
class SQLFormatter /* implements \HH\SQLFormatter */ { - public function __construct(private QueryFormat $queryFormat) {} - public function format_0x25(): string { - return '%'; - } - public function format_d(?int $int): string { - return $this->queryFormat->format_d($int); - } - public function format_f(?float $float): string { - return $this->queryFormat->format_f($float); - } - public function format_s(?string $string): string { - return $this->queryFormat->format_s($string); - } - public function format_upcase_t(string $table): string { - return $this->queryFormat->format_upcase_t($table); - } - public function format_upcase_c(string $column): string { - return $this->queryFormat->format_upcase_c($column); - } - public function format_upcase_l(): SQLListFormatter { - return new SQLListFormatter(new ListFormat()); - } - public function format_0x3d(): SQLEqualsScalarFormatter { - return new SQLEqualsScalarFormatter(new EqualsScalarFormat()); - } + public function __construct(private QueryFormat $queryFormat) {} + public function format_0x25(): string { + return '%'; + } + public function format_d(?int $int): string { + return $this->queryFormat->format_d($int); + } + public function format_f(?float $float): string { + return $this->queryFormat->format_f($float); + } + public function format_s(?string $string): string { + return $this->queryFormat->format_s($string); + } + public function format_upcase_t(string $table): string { + return $this->queryFormat->format_upcase_t($table); + } + public function format_upcase_c(string $column): string { + return $this->queryFormat->format_upcase_c($column); + } + public function format_upcase_l(): SQLListFormatter { + return new SQLListFormatter(new ListFormat()); + } + public function format_0x3d(): SQLEqualsScalarFormatter { + return new SQLEqualsScalarFormatter(new EqualsScalarFormat()); + } } diff --git a/src/Init.php b/src/Init.php index 5fef2aa..fe65ffe 100644 --- a/src/Init.php +++ b/src/Init.php @@ -17,24 +17,24 @@ 
* If strict mode is provided (recommended), SQLFake will throw an exception on any query referencing tables not in the schema. */ function init( - dict> $schema = dict[], - bool $strict_sql = false, - bool $strict_schema = false, + dict> $schema = dict[], + bool $strict_sql = false, + bool $strict_schema = false, ): void { - QueryContext::$schema = $schema; - QueryContext::$strictSQLMode = $strict_sql; - QueryContext::$strictSchemaMode = $strict_schema; + QueryContext::$schema = $schema; + QueryContext::$strictSQLMode = $strict_sql; + QueryContext::$strictSchemaMode = $strict_schema; } function add_server(string $hostname, server_config $config): void { - $server = Server::getOrCreate($hostname); - $server->setConfig($config); + $server = Server::getOrCreate($hostname); + $server->setConfig($config); } function snapshot(string $name): void { - Server::snapshot($name); + Server::snapshot($name); } function restore(string $name): void { - Server::restore($name); + Server::restore($name); } diff --git a/src/JSONPath/Exception.hack b/src/JSONPath/Exception.hack index 643fa1c..e2d430a 100644 --- a/src/JSONPath/Exception.hack +++ b/src/JSONPath/Exception.hack @@ -26,17 +26,17 @@ class JSONException extends \Exception {} final class InvalidJSONException extends JSONException {} final class DivergentJSONPathSetException extends JSONException {} final class InvalidJSONPathException extends JSONException { - private string $token; + private string $token; - /** - * Class constructor - * - * @param string $token token related to the JSONPath error - * - * @return void - */ - public function __construct(string $token) { - $this->token = $token; - parent::__construct("Error in JSONPath near '".$token."'", 0, null); - } + /** + * Class constructor + * + * @param string $token token related to the JSONPath error + * + * @return void + */ + public function __construct(string $token) { + $this->token = $token; + parent::__construct("Error in JSONPath near '".$token."'", 0, null); + 
} } diff --git a/src/JSONPath/JSONObject.hack b/src/JSONPath/JSONObject.hack index 436b5ff..7485eda 100755 --- a/src/JSONPath/JSONObject.hack +++ b/src/JSONPath/JSONObject.hack @@ -27,21 +27,21 @@ use namespace HH\Lib\{C, Regex, Str, Vec}; type ExplodedPathType = vec; type ObjectAtPath = shape( - 'object' => mixed, - 'path' => ExplodedPathType, + 'object' => mixed, + 'path' => ExplodedPathType, ); type MatchedObjectsResult = shape( - 'matched' => vec, - 'divergingPath' => bool, + 'matched' => vec, + 'divergingPath' => bool, ); type GetOptions = shape( - 'unwrap' => bool, // Directly return the object found (no outer vec[]) for non-branching path. default: false + 'unwrap' => bool, // Directly return the object found (no outer vec[]) for non-branching path. default: false ); class WrappedResult { - public function __construct(public T $value) {} + public function __construct(public T $value) {} } /** @@ -93,465 +93,465 @@ class WrappedResult { * */ class JSONObject { - // Tokens - const TOK_ROOT = '$'; - const TOK_SELECTOR_BEGIN = '['; - const TOK_SELECTOR_END = ']'; - const TOK_ALL = '*'; - const TOK_CHILD_ACCESS_BEGIN = '.'; - const TOK_DOUBLE_ASTERISK = '**'; - - private mixed $jsonObject; // This is any valid type in JSON (vec, dict, num, string, null) - - /** - * Class constructor. - * If $json is null the json object contained will be initialized empty. 
- */ - public function __construct(mixed $json = null) { - if ($json is string) { - $this->jsonObject = \json_decode($json, true, 512, \JSON_FB_HACK_ARRAYS); - if ($json !== 'null' && $this->jsonObject === null) { - throw new InvalidJSONException('string does not contain a valid JSON object.'); - } - } else if ($json is vec<_> || $json is dict<_, _> || \is_object($json)) { - // We encode & decode here to make sure we only have nested dicts/vecs & dict keys are string only - $this->jsonObject = \json_decode(\json_encode($json), true, 512, \JSON_FB_HACK_ARRAYS); - } else { - throw new InvalidJSONException('value does not encode a JSON object.'); - } - } - - /** - * Returns the value of the JSON object - */ - public function getValue(): mixed { - return $this->jsonObject; - } - - /** - * Returns the JSON representation of the JSON object - */ - public function getJSON(): string { - return \json_encode($this->jsonObject); - } - - /** - * Returns an vec containing objects that match the JsonPath. - * - * If smartGet was set to true when creating the instance and - * the JsonPath given does not branch, it will return the value - * instead of a vec of size 1. - */ - public function get(string $jsonPath, GetOptions $opts = shape('unwrap' => false)): ?WrappedResult { - $unwrap = $opts['unwrap'] ?? false; - - $jsonObject = $this->jsonObject; - $result = self::getMatching($jsonObject, $jsonPath); - - $matched = $result['matched']; - $divergingPath = $result['divergingPath']; - if (C\is_empty($matched)) { - return null; - } - - if ($unwrap && !$divergingPath) { - return new WrappedResult($matched[0]['object']); - } - - return new WrappedResult(Vec\map($matched, $o ==> $o['object'])); - } - - /** - * Replaces the element that results from the $jsonPath query - * to $value. This method disallows divering (wildcard) paths. - * It is a no-op if the given path doesn't already exist. - * - * This method returns a new JsonObject with the new value. 
- */ - public function replace(string $jsonPath, mixed $value): WrappedResult { - $result = self::getMatching($this->jsonObject, $jsonPath); - if ($result['divergingPath']) { - throw new DivergentJSONPathSetException('Cannot set a value using a wildcard JSON path'); - } - - $out = self::setPathsToValue( - shape('object' => $this->jsonObject, 'path' => vec[self::TOK_ROOT]), - Vec\map($result['matched'], $m ==> $m['path']), - $value, - ); - - return new WrappedResult(new JSONObject(\json_encode($out))); - } - - /** - * Returns the keys of the object located by the $jsonPath query. - * This method disallows divering (wildcard) paths. - * - * This method returns null (if object not located) or a vec[] of the keys. - */ - public function keys(string $jsonPath = '$'): ?WrappedResult> { - $result = self::getMatching($this->jsonObject, $jsonPath); - if ($result['divergingPath']) { - throw new DivergentJSONPathSetException('Cannot get keys using a wildcard JSON path'); - } - - $matches = $result['matched']; - if (C\is_empty($matches)) { - return null; - } - - $matched = $matches[0]['object']; - if (!($matched is dict<_, _>)) { - return null; - } - - $keys = vec[]; - foreach ($matched as $k => $_v) { - invariant($k is string, 'dict cannot have non string key in JSON'); - $keys[] = $k; - } - - return new WrappedResult($keys); - } - - /** - * Returns the length of the value located by the $jsonPath query. - * This method disallows divering (wildcard) paths. - * - * This method returns null (if object not located) or a vec[] of the keys. 
- */ - public function length(string $jsonPath = '$'): ?WrappedResult { - $result = self::getMatching($this->jsonObject, $jsonPath); - if ($result['divergingPath']) { - throw new DivergentJSONPathSetException('Cannot get length using a wildcard JSON path'); - } - - $matches = $result['matched']; - if (C\is_empty($matches)) { - return null; - } - - $matched = $matches[0]['object']; - if ($matched is dict<_, _> || $matched is vec<_>) { - return new WrappedResult(C\count($matched)); - } - - return new WrappedResult(1); - } - - /** - * Returns the maximum depth of the value. - */ - public function depth(): WrappedResult { - $depth = 0; - - $objects = vec[$this->jsonObject]; - while (!C\is_empty($objects)) { - $children = vec[]; - - foreach ($objects as $object) { - if ($object is dict<_, _> || $object is vec<_>) { - foreach ($object as $child) { - $children[] = $child; - } - } - } - - // Each time we enter this while loop, it means we just processed a new level - $depth += 1; - $objects = $children; - } - - return new WrappedResult($depth); - } - - private static function pathMatched(vec $paths, ExplodedPathType $path): bool { - return C\contains($paths, $path); - } - - private static function setPathsToValue(ObjectAtPath $object, vec $paths, mixed $value): mixed { - $jsonObject = $object['object']; - $path = $object['path']; - - // Found something to be replaced! - if (self::pathMatched($paths, $path)) { - return $value; - } - - if ($jsonObject is vec<_> || $jsonObject is dict<_, _>) { - $out = $jsonObject is vec<_> ? 
vec[] : dict[]; - foreach ($jsonObject as $key => $original_value) { - $childPath = Vec\concat($path, vec[$key]); - $newValue = self::setPathsToValue( - shape('object' => $original_value, 'path' => $childPath), - $paths, - $value, - ); - - if ($out is vec<_>) { - $out[] = $newValue; - } else if ($out is dict<_, _>) { - $out[$key] = $newValue; - } - } - - return $out; - } - - return $jsonObject; - } - - private static function matchArrayIndex(string $jsonPath): ?shape('index' => string, 'rest' => string) { - if (Str\is_empty($jsonPath) || $jsonPath[0] !== self::TOK_SELECTOR_BEGIN) { - return null; - } - - $array_idx_regex = re"/^\[(?\*|\d+)\](?.*)$/"; - $matched = Regex\first_match($jsonPath, $array_idx_regex); - if ($matched) { - return shape( - 'index' => $matched['index'], - 'rest' => $matched['rest'], - ); - } - - return null; - } - - private static function matchChildAccess(string $jsonPath): ?shape('child' => string, 'rest' => string) { - if (Str\is_empty($jsonPath) || $jsonPath[0] != self::TOK_CHILD_ACCESS_BEGIN) { - return null; - } - - $child_name_regex = re"/^\.(?(?\"?)[[:alpha:]_$][a-zA-Z0-9_\-\$]*(?\\2)|\*)(?.*)$/"; - $matched = Regex\first_match($jsonPath, $child_name_regex); - - if ($matched) { - // Remove double quotedness if matched - if ($matched['quote0'] === '"' && $matched['quote1'] === '"') { - $child = Str\strip_prefix($matched['child'], '"') |> Str\strip_suffix($$, '"'); - } else { - $child = $matched['child']; - } - - return shape( - 'child' => $child, - 'rest' => $matched['rest'], - ); - } - - return null; - } - - private static function opChildName(ObjectAtPath $objectAtPath, string $childName): MatchedObjectsResult { - $out = vec[]; - - $jsonObject = $objectAtPath['object']; - $path = $objectAtPath['path']; - - $diverged = false; - - if ($jsonObject is KeyedContainer<_, _>) { - if ($childName === self::TOK_ALL) { - $diverged = true; - - foreach ($jsonObject as $key => $item) { - // We ignore int indices (varray) here - if ($key is int) { - 
continue; - } - - $out[] = shape( - 'object' => $item, - 'path' => Vec\concat($path, vec[$key]), - ); - } - } else if (C\contains_key($jsonObject, $childName)) { - $out[] = shape( - 'object' => $jsonObject[$childName], - 'path' => Vec\concat($path, vec[$childName]), - ); - } - } - - return shape('matched' => $out, 'divergingPath' => $diverged); - } - - private static function opChildSelector(ObjectAtPath $objectAtPath, string $contents): MatchedObjectsResult { - $jsonObject = $objectAtPath['object']; - $path = $objectAtPath['path']; - - if ($jsonObject is KeyedContainer<_, _>) { - if ($contents === self::TOK_ALL) { - $out = vec[]; - foreach ($jsonObject as $key => $item) { - $out[] = shape( - 'object' => $item, - 'path' => Vec\concat($path, vec[$key]), - ); - } - return shape('matched' => $out, 'divergingPath' => true); - } - - $index = Str\to_int($contents); - if ($index is nonnull) { - if (C\contains_key($jsonObject, $index)) { - return shape( - 'matched' => vec[shape( - 'object' => $jsonObject[$index], - 'path' => Vec\concat($path, vec[$index]), - )], - 'divergingPath' => false, - ); - } - - return shape('matched' => vec[], 'divergingPath' => false); - } - - throw new InvalidJSONPathException($contents); - } - - return shape('matched' => vec[], 'divergingPath' => false); - } - - private static function matchRecursiveSelector( - string $jsonPath, - ): ?shape(?'index' => string, ?'child' => string, 'rest' => string) { - if (!Str\starts_with($jsonPath, self::TOK_DOUBLE_ASTERISK)) { - return null; - } - - $jsonPath = Str\strip_prefix($jsonPath, self::TOK_DOUBLE_ASTERISK); - - $matchedArrayIndex = self::matchArrayIndex($jsonPath); - if ($matchedArrayIndex) { - return shape('index' => $matchedArrayIndex['index'], 'rest' => $matchedArrayIndex['rest']); - } - - $matchedChildAccess = self::matchChildAccess($jsonPath); - if ($matchedChildAccess) { - return shape('child' => $matchedChildAccess['child'], 'rest' => $matchedChildAccess['rest']); - } - - return null; - } - - 
private static function opRecursiveSelector( - ObjectAtPath $objectAtPath, - shape(?'index' => string, ?'child' => string, 'rest' => string) $matched, - ): MatchedObjectsResult { - $out = vec[]; - - $jsonObject = $objectAtPath['object']; - $path = $objectAtPath['path']; - - $childName = $matched['child'] ?? null; - if ($childName is nonnull) { - $ret = self::opChildName($objectAtPath, $childName); - $out = Vec\concat($out, $ret['matched']); - - if ($jsonObject is KeyedContainer<_, _>) { - foreach ($jsonObject as $key => $value) { - $ret = self::opRecursiveSelector( - shape('object' => $value, 'path' => Vec\concat($path, vec[$key])), - $matched, - ); - $out = Vec\concat($out, $ret['matched']); - } - } - } - - $index = $matched['index'] ?? null; - if ($index is nonnull) { - $ret = self::opChildSelector($objectAtPath, $index); - $out = Vec\concat($out, $ret['matched']); - - if ($jsonObject is KeyedContainer<_, _>) { - foreach ($jsonObject as $key => $value) { - $ret = self::opRecursiveSelector( - shape('object' => $value, 'path' => Vec\concat($path, vec[$key])), - $matched, - ); - $out = Vec\concat($out, $ret['matched']); - } - } - } - - return shape('matched' => $out, 'divergingPath' => true); - } - - static private function getMatching(mixed $jsonObject, string $jsonPath): MatchedObjectsResult { - if (!Str\starts_with($jsonPath, self::TOK_ROOT)) { - throw new InvalidJSONPathException($jsonPath); - } - - $jsonPath = Str\strip_prefix($jsonPath, self::TOK_ROOT); - $selection = vec[shape('object' => $jsonObject, 'path' => vec[self::TOK_ROOT])]; - $divergingPath = false; - while (!Str\is_empty($jsonPath) && !C\is_empty($selection)) { - $newSelection = vec[]; - - $matchedChildAccess = self::matchChildAccess($jsonPath); - if ($matchedChildAccess is nonnull) { - foreach ($selection as $jsonObject) { - $ret = self::opChildName($jsonObject, $matchedChildAccess['child']); - $divergingPath = $divergingPath || $ret['divergingPath']; - $newSelection = Vec\concat($newSelection, 
$ret['matched']); - } - - if (C\is_empty($newSelection)) { - $selection = vec[]; - break; - } else { - $jsonPath = $matchedChildAccess['rest']; - } - - $selection = $newSelection; - continue; - } - - $matchedArrayIndex = self::matchArrayIndex($jsonPath); - if ($matchedArrayIndex) { - $index = $matchedArrayIndex['index']; - foreach ($selection as $jsonObject) { - $ret = self::opChildSelector($jsonObject, $index); - $divergingPath = $divergingPath || $ret['divergingPath']; - $newSelection = Vec\concat($newSelection, $ret['matched']); - } - - if (C\is_empty($newSelection)) { - $selection = vec[]; - break; - } else { - $jsonPath = $matchedArrayIndex['rest']; - } - - $selection = $newSelection; - continue; - } - - $matchedRecursive = self::matchRecursiveSelector($jsonPath); - if ($matchedRecursive) { - foreach ($selection as $jsonObject) { - $ret = self::opRecursiveSelector($jsonObject, $matchedRecursive); - $divergingPath = $divergingPath || $ret['divergingPath']; - $newSelection = Vec\concat($newSelection, $ret['matched']); - } - - if (C\is_empty($newSelection)) { - $selection = vec[]; - break; - } else { - $jsonPath = $matchedRecursive['rest']; - } - - $selection = $newSelection; - continue; - } - - throw new InvalidJSONPathException($jsonPath); - } - - return shape('matched' => $selection, 'divergingPath' => $divergingPath); - } + // Tokens + const TOK_ROOT = '$'; + const TOK_SELECTOR_BEGIN = '['; + const TOK_SELECTOR_END = ']'; + const TOK_ALL = '*'; + const TOK_CHILD_ACCESS_BEGIN = '.'; + const TOK_DOUBLE_ASTERISK = '**'; + + private mixed $jsonObject; // This is any valid type in JSON (vec, dict, num, string, null) + + /** + * Class constructor. + * If $json is null the json object contained will be initialized empty. 
+ */ + public function __construct(mixed $json = null) { + if ($json is string) { + $this->jsonObject = \json_decode($json, true, 512, \JSON_FB_HACK_ARRAYS); + if ($json !== 'null' && $this->jsonObject === null) { + throw new InvalidJSONException('string does not contain a valid JSON object.'); + } + } else if ($json is vec<_> || $json is dict<_, _> || \is_object($json)) { + // We encode & decode here to make sure we only have nested dicts/vecs & dict keys are string only + $this->jsonObject = \json_decode(\json_encode($json), true, 512, \JSON_FB_HACK_ARRAYS); + } else { + throw new InvalidJSONException('value does not encode a JSON object.'); + } + } + + /** + * Returns the value of the JSON object + */ + public function getValue(): mixed { + return $this->jsonObject; + } + + /** + * Returns the JSON representation of the JSON object + */ + public function getJSON(): string { + return \json_encode($this->jsonObject); + } + + /** + * Returns a wrapped vec of the objects that match the JsonPath, or null when nothing matches. + * + * If the 'unwrap' option is true and + * the JsonPath given does not branch, it will return the value + * instead of a vec of size 1. + */ + public function get(string $jsonPath, GetOptions $opts = shape('unwrap' => false)): ?WrappedResult { + $unwrap = $opts['unwrap'] ?? false; + + $jsonObject = $this->jsonObject; + $result = self::getMatching($jsonObject, $jsonPath); + + $matched = $result['matched']; + $divergingPath = $result['divergingPath']; + if (C\is_empty($matched)) { + return null; + } + + if ($unwrap && !$divergingPath) { + return new WrappedResult($matched[0]['object']); + } + + return new WrappedResult(Vec\map($matched, $o ==> $o['object'])); + } + + /** + * Replaces the element that results from the $jsonPath query + * to $value. This method disallows diverging (wildcard) paths. + * It is a no-op if the given path doesn't already exist. + * + * This method returns a new JsonObject with the new value.
+ */ + public function replace(string $jsonPath, mixed $value): WrappedResult { + $result = self::getMatching($this->jsonObject, $jsonPath); + if ($result['divergingPath']) { + throw new DivergentJSONPathSetException('Cannot set a value using a wildcard JSON path'); + } + + $out = self::setPathsToValue( + shape('object' => $this->jsonObject, 'path' => vec[self::TOK_ROOT]), + Vec\map($result['matched'], $m ==> $m['path']), + $value, + ); + + return new WrappedResult(new JSONObject(\json_encode($out))); + } + + /** + * Returns the keys of the object located by the $jsonPath query. + * This method disallows diverging (wildcard) paths. + * + * This method returns null (if object not located or not a dict) or a vec[] of the keys. + */ + public function keys(string $jsonPath = '$'): ?WrappedResult> { + $result = self::getMatching($this->jsonObject, $jsonPath); + if ($result['divergingPath']) { + throw new DivergentJSONPathSetException('Cannot get keys using a wildcard JSON path'); + } + + $matches = $result['matched']; + if (C\is_empty($matches)) { + return null; + } + + $matched = $matches[0]['object']; + if (!($matched is dict<_, _>)) { + return null; + } + + $keys = vec[]; + foreach ($matched as $k => $_v) { + invariant($k is string, 'dict cannot have non string key in JSON'); + $keys[] = $k; + } + + return new WrappedResult($keys); + } + + /** + * Returns the length of the value located by the $jsonPath query. + * This method disallows diverging (wildcard) paths. + * + * This method returns null (if object not located) or the element count (1 for a non-container value).
+ */ + public function length(string $jsonPath = '$'): ?WrappedResult { + $result = self::getMatching($this->jsonObject, $jsonPath); + if ($result['divergingPath']) { + throw new DivergentJSONPathSetException('Cannot get length using a wildcard JSON path'); + } + + $matches = $result['matched']; + if (C\is_empty($matches)) { + return null; + } + + $matched = $matches[0]['object']; + if ($matched is dict<_, _> || $matched is vec<_>) { + return new WrappedResult(C\count($matched)); + } + + return new WrappedResult(1); + } + + /** + * Returns the maximum depth of the value. + */ + public function depth(): WrappedResult { + $depth = 0; + + $objects = vec[$this->jsonObject]; + while (!C\is_empty($objects)) { + $children = vec[]; + + foreach ($objects as $object) { + if ($object is dict<_, _> || $object is vec<_>) { + foreach ($object as $child) { + $children[] = $child; + } + } + } + + // Each time we enter this while loop, it means we just processed a new level + $depth += 1; + $objects = $children; + } + + return new WrappedResult($depth); + } + + private static function pathMatched(vec $paths, ExplodedPathType $path): bool { + return C\contains($paths, $path); + } + + private static function setPathsToValue(ObjectAtPath $object, vec $paths, mixed $value): mixed { + $jsonObject = $object['object']; + $path = $object['path']; + + // Found something to be replaced! + if (self::pathMatched($paths, $path)) { + return $value; + } + + if ($jsonObject is vec<_> || $jsonObject is dict<_, _>) { + $out = $jsonObject is vec<_> ? 
vec[] : dict[]; + foreach ($jsonObject as $key => $original_value) { + $childPath = Vec\concat($path, vec[$key]); + $newValue = self::setPathsToValue( + shape('object' => $original_value, 'path' => $childPath), + $paths, + $value, + ); + + if ($out is vec<_>) { + $out[] = $newValue; + } else if ($out is dict<_, _>) { + $out[$key] = $newValue; + } + } + + return $out; + } + + return $jsonObject; + } + + private static function matchArrayIndex(string $jsonPath): ?shape('index' => string, 'rest' => string) { + if (Str\is_empty($jsonPath) || $jsonPath[0] !== self::TOK_SELECTOR_BEGIN) { + return null; + } + + $array_idx_regex = re"/^\[(?\*|\d+)\](?.*)$/"; + $matched = Regex\first_match($jsonPath, $array_idx_regex); + if ($matched) { + return shape( + 'index' => $matched['index'], + 'rest' => $matched['rest'], + ); + } + + return null; + } + + private static function matchChildAccess(string $jsonPath): ?shape('child' => string, 'rest' => string) { + if (Str\is_empty($jsonPath) || $jsonPath[0] != self::TOK_CHILD_ACCESS_BEGIN) { + return null; + } + + $child_name_regex = re"/^\.(?(?\"?)[[:alpha:]_$][a-zA-Z0-9_\-\$]*(?\\2)|\*)(?.*)$/"; + $matched = Regex\first_match($jsonPath, $child_name_regex); + + if ($matched) { + // Remove double quotedness if matched + if ($matched['quote0'] === '"' && $matched['quote1'] === '"') { + $child = Str\strip_prefix($matched['child'], '"') |> Str\strip_suffix($$, '"'); + } else { + $child = $matched['child']; + } + + return shape( + 'child' => $child, + 'rest' => $matched['rest'], + ); + } + + return null; + } + + private static function opChildName(ObjectAtPath $objectAtPath, string $childName): MatchedObjectsResult { + $out = vec[]; + + $jsonObject = $objectAtPath['object']; + $path = $objectAtPath['path']; + + $diverged = false; + + if ($jsonObject is KeyedContainer<_, _>) { + if ($childName === self::TOK_ALL) { + $diverged = true; + + foreach ($jsonObject as $key => $item) { + // We ignore int indices (varray) here + if ($key is int) { + 
continue; + } + + $out[] = shape( + 'object' => $item, + 'path' => Vec\concat($path, vec[$key]), + ); + } + } else if (C\contains_key($jsonObject, $childName)) { + $out[] = shape( + 'object' => $jsonObject[$childName], + 'path' => Vec\concat($path, vec[$childName]), + ); + } + } + + return shape('matched' => $out, 'divergingPath' => $diverged); + } + + private static function opChildSelector(ObjectAtPath $objectAtPath, string $contents): MatchedObjectsResult { + $jsonObject = $objectAtPath['object']; + $path = $objectAtPath['path']; + + if ($jsonObject is KeyedContainer<_, _>) { + if ($contents === self::TOK_ALL) { + $out = vec[]; + foreach ($jsonObject as $key => $item) { + $out[] = shape( + 'object' => $item, + 'path' => Vec\concat($path, vec[$key]), + ); + } + return shape('matched' => $out, 'divergingPath' => true); + } + + $index = Str\to_int($contents); + if ($index is nonnull) { + if (C\contains_key($jsonObject, $index)) { + return shape( + 'matched' => vec[shape( + 'object' => $jsonObject[$index], + 'path' => Vec\concat($path, vec[$index]), + )], + 'divergingPath' => false, + ); + } + + return shape('matched' => vec[], 'divergingPath' => false); + } + + throw new InvalidJSONPathException($contents); + } + + return shape('matched' => vec[], 'divergingPath' => false); + } + + private static function matchRecursiveSelector( + string $jsonPath, + ): ?shape(?'index' => string, ?'child' => string, 'rest' => string) { + if (!Str\starts_with($jsonPath, self::TOK_DOUBLE_ASTERISK)) { + return null; + } + + $jsonPath = Str\strip_prefix($jsonPath, self::TOK_DOUBLE_ASTERISK); + + $matchedArrayIndex = self::matchArrayIndex($jsonPath); + if ($matchedArrayIndex) { + return shape('index' => $matchedArrayIndex['index'], 'rest' => $matchedArrayIndex['rest']); + } + + $matchedChildAccess = self::matchChildAccess($jsonPath); + if ($matchedChildAccess) { + return shape('child' => $matchedChildAccess['child'], 'rest' => $matchedChildAccess['rest']); + } + + return null; + } + + 
private static function opRecursiveSelector( + ObjectAtPath $objectAtPath, + shape(?'index' => string, ?'child' => string, 'rest' => string) $matched, + ): MatchedObjectsResult { + $out = vec[]; + + $jsonObject = $objectAtPath['object']; + $path = $objectAtPath['path']; + + $childName = $matched['child'] ?? null; + if ($childName is nonnull) { + $ret = self::opChildName($objectAtPath, $childName); + $out = Vec\concat($out, $ret['matched']); + + if ($jsonObject is KeyedContainer<_, _>) { + foreach ($jsonObject as $key => $value) { + $ret = self::opRecursiveSelector( + shape('object' => $value, 'path' => Vec\concat($path, vec[$key])), + $matched, + ); + $out = Vec\concat($out, $ret['matched']); + } + } + } + + $index = $matched['index'] ?? null; + if ($index is nonnull) { + $ret = self::opChildSelector($objectAtPath, $index); + $out = Vec\concat($out, $ret['matched']); + + if ($jsonObject is KeyedContainer<_, _>) { + foreach ($jsonObject as $key => $value) { + $ret = self::opRecursiveSelector( + shape('object' => $value, 'path' => Vec\concat($path, vec[$key])), + $matched, + ); + $out = Vec\concat($out, $ret['matched']); + } + } + } + + return shape('matched' => $out, 'divergingPath' => true); + } + + static private function getMatching(mixed $jsonObject, string $jsonPath): MatchedObjectsResult { + if (!Str\starts_with($jsonPath, self::TOK_ROOT)) { + throw new InvalidJSONPathException($jsonPath); + } + + $jsonPath = Str\strip_prefix($jsonPath, self::TOK_ROOT); + $selection = vec[shape('object' => $jsonObject, 'path' => vec[self::TOK_ROOT])]; + $divergingPath = false; + while (!Str\is_empty($jsonPath) && !C\is_empty($selection)) { + $newSelection = vec[]; + + $matchedChildAccess = self::matchChildAccess($jsonPath); + if ($matchedChildAccess is nonnull) { + foreach ($selection as $jsonObject) { + $ret = self::opChildName($jsonObject, $matchedChildAccess['child']); + $divergingPath = $divergingPath || $ret['divergingPath']; + $newSelection = Vec\concat($newSelection, 
$ret['matched']); + } + + if (C\is_empty($newSelection)) { + $selection = vec[]; + break; + } else { + $jsonPath = $matchedChildAccess['rest']; + } + + $selection = $newSelection; + continue; + } + + $matchedArrayIndex = self::matchArrayIndex($jsonPath); + if ($matchedArrayIndex) { + $index = $matchedArrayIndex['index']; + foreach ($selection as $jsonObject) { + $ret = self::opChildSelector($jsonObject, $index); + $divergingPath = $divergingPath || $ret['divergingPath']; + $newSelection = Vec\concat($newSelection, $ret['matched']); + } + + if (C\is_empty($newSelection)) { + $selection = vec[]; + break; + } else { + $jsonPath = $matchedArrayIndex['rest']; + } + + $selection = $newSelection; + continue; + } + + $matchedRecursive = self::matchRecursiveSelector($jsonPath); + if ($matchedRecursive) { + foreach ($selection as $jsonObject) { + $ret = self::opRecursiveSelector($jsonObject, $matchedRecursive); + $divergingPath = $divergingPath || $ret['divergingPath']; + $newSelection = Vec\concat($newSelection, $ret['matched']); + } + + if (C\is_empty($newSelection)) { + $selection = vec[]; + break; + } else { + $jsonPath = $matchedRecursive['rest']; + } + + $selection = $newSelection; + continue; + } + + throw new InvalidJSONPathException($jsonPath); + } + + return shape('matched' => $selection, 'divergingPath' => $divergingPath); + } } diff --git a/src/Metrics.php b/src/Metrics.php index ca88219..034660d 100644 --- a/src/Metrics.php +++ b/src/Metrics.php @@ -5,25 +5,25 @@ use namespace HH\Lib\{C, Str, Vec}; enum QueryType: string { - INSERT = 'insert'; - SELECT = 'select'; - UPDATE = 'update'; - DELETE = 'delete'; + INSERT = 'insert'; + SELECT = 'select'; + UPDATE = 'update'; + DELETE = 'delete'; } type query_counts = shape( - QueryType::INSERT => int, - QueryType::SELECT => int, - QueryType::UPDATE => int, - QueryType::DELETE => int, + QueryType::INSERT => int, + QueryType::SELECT => int, + QueryType::UPDATE => int, + QueryType::DELETE => int, ); type query_log = shape( 
- 'type' => QueryType, - 'host' => string, - 'table' => string, - 'sql' => string, - ?'callstack' => string, + 'type' => QueryType, + 'host' => string, + 'table' => string, + 'sql' => string, + ?'callstack' => string, ); /** @@ -31,152 +31,153 @@ enum QueryType: string { */ abstract final class Metrics { - /** - * Recording call stacks for each query gets expensive, - * only turn this on if you have a good use for it - */ - public static bool $enableCallstacks = false; - public static bool $enable = false; - - /** - * Filter out function names matching these patterns from the beginning of your callstack to make the stacks more concise - * uses fnmatch() syntax - */ - public static keyset $stackIgnorePatterns = keyset[]; - public static vec $queryMetrics = vec[]; - - public static function getQueryMetrics(): vec { - return self::$queryMetrics; - } - - public static function reset(): void { - self::$queryMetrics = vec[]; - self::$enable = false; - } - - public static function getCountByQueryType(): query_counts { - $totals = shape( - QueryType::SELECT => 0, - QueryType::INSERT => 0, - QueryType::DELETE => 0, - QueryType::UPDATE => 0, - ); - - foreach (self::$queryMetrics as $metric) { - switch ($metric['type']) { - case QueryType::SELECT: - $totals[QueryType::SELECT]++; - break; - case QueryType::INSERT: - $totals[QueryType::INSERT]++; - break; - case QueryType::DELETE: - $totals[QueryType::DELETE]++; - break; - case QueryType::UPDATE: - $totals[QueryType::UPDATE]++; - break; - } - } - - return $totals; - } - - public static function getTotalQueryCount(): int { - return C\count(self::$queryMetrics); - } - - /** - * Log a query - * While a query may hit multiple tables, we only include the first one currently - */ - public static function trackQuery(QueryType $type, string $host, string $table_name, string $sql): void { - - if (!self::$enable){ - return; - } - - $metric = shape( - 'type' => $type, - 'host' => $host, - 'table' => $table_name, - 'sql' => $sql, - ); - - 
if (self::$enableCallstacks) { - $metric['callstack'] = self::getBacktrace(); - } - - self::$queryMetrics[] = $metric; - } - - protected static function getBacktrace(): string { - $trace = \debug_backtrace(\DEBUG_BACKTRACE_IGNORE_ARGS); - while (!C\is_empty($trace)) { - $matched = false; - - // filter out this library - if ( - Str\contains($trace[0]['file'] ?? '', \realpath(__DIR__.'/..')) || ($trace[0]['class'] ?? '') === AsyncMysqlConnection::class - ) { - $trace = Vec\drop($trace, 1); - continue; - } - - // filter out ignored patterns - foreach (self::$stackIgnorePatterns as $pattern) { - if (\fnmatch($pattern, $trace[0]['function'] ?? '')) { - $trace = Vec\drop($trace, 1); - $matched = true; - break; - } - } - - // as soon as we find an item in the trace that isn't in the ignore list, we're done - if (!$matched) { - break; - } - } - - return Vec\reverse($trace) - |> Vec\map($$, $entry ==> self::formatStackEntry($entry)) - |> Str\join($$, ' -> '); - } - - /** - * Returns something like my_file.php:123#my_function() - */ - protected static function formatStackEntry( - shape('function' => ?string, 'class' => ?string, ?'file' => string, ?'line' => int) $entry, - ): string { - - $file = $entry['file'] ?? ''; - $line = $entry['line'] ?? null; - $function = $entry['function'] ?? ''; - $class = $entry['class'] ?? 
''; - - $formatted = ''; - - // my_file.php - if (!Str\is_empty($file)) { - $formatted = Str\split($file, '/') |> $$[C\count($$) - 1]; - } - - // :123 - if ($line is nonnull) { - $formatted .= ':'.$line; - } - - // Foo::function() or my_function() - if (!Str\is_empty($function)) { - if (!Str\is_empty($class)) { - $formatted .= '#'.$class.'::'.$function.'()'; - } else { - $formatted .= '#'.$function.'()'; - } - } - - return $formatted; - } + /** + * Recording call stacks for each query gets expensive, + * only turn this on if you have a good use for it + */ + public static bool $enableCallstacks = false; + public static bool $enable = false; + + /** + * Filter out function names matching these patterns from the beginning of your callstack to make the stacks more concise + * uses fnmatch() syntax + */ + public static keyset $stackIgnorePatterns = keyset[]; + public static vec $queryMetrics = vec[]; + + public static function getQueryMetrics(): vec { + return self::$queryMetrics; + } + + public static function reset(): void { + self::$queryMetrics = vec[]; + self::$enable = false; + } + + public static function getCountByQueryType(): query_counts { + $totals = shape( + QueryType::SELECT => 0, + QueryType::INSERT => 0, + QueryType::DELETE => 0, + QueryType::UPDATE => 0, + ); + + foreach (self::$queryMetrics as $metric) { + switch ($metric['type']) { + case QueryType::SELECT: + $totals[QueryType::SELECT]++; + break; + case QueryType::INSERT: + $totals[QueryType::INSERT]++; + break; + case QueryType::DELETE: + $totals[QueryType::DELETE]++; + break; + case QueryType::UPDATE: + $totals[QueryType::UPDATE]++; + break; + } + } + + return $totals; + } + + public static function getTotalQueryCount(): int { + return C\count(self::$queryMetrics); + } + + /** + * Log a query + * While a query may hit multiple tables, we only include the first one currently + */ + public static function trackQuery(QueryType $type, string $host, string $table_name, string $sql): void { + + if 
(!self::$enable) { + return; + } + + $metric = shape( + 'type' => $type, + 'host' => $host, + 'table' => $table_name, + 'sql' => $sql, + ); + + if (self::$enableCallstacks) { + $metric['callstack'] = self::getBacktrace(); + } + + self::$queryMetrics[] = $metric; + } + + protected static function getBacktrace(): string { + $trace = \debug_backtrace(\DEBUG_BACKTRACE_IGNORE_ARGS); + while (!C\is_empty($trace)) { + $matched = false; + + // filter out this library + if ( + Str\contains($trace[0]['file'] ?? '', \realpath(__DIR__.'/..')) || + ($trace[0]['class'] ?? '') === AsyncMysqlConnection::class + ) { + $trace = Vec\drop($trace, 1); + continue; + } + + // filter out ignored patterns + foreach (self::$stackIgnorePatterns as $pattern) { + if (\fnmatch($pattern, $trace[0]['function'] ?? '')) { + $trace = Vec\drop($trace, 1); + $matched = true; + break; + } + } + + // as soon as we find an item in the trace that isn't in the ignore list, we're done + if (!$matched) { + break; + } + } + + return Vec\reverse($trace) + |> Vec\map($$, $entry ==> self::formatStackEntry($entry)) + |> Str\join($$, ' -> '); + } + + /** + * Returns something like my_file.php:123#my_function() + */ + protected static function formatStackEntry( + shape('function' => ?string, 'class' => ?string, ?'file' => string, ?'line' => int) $entry, + ): string { + + $file = $entry['file'] ?? ''; + $line = $entry['line'] ?? null; + $function = $entry['function'] ?? ''; + $class = $entry['class'] ?? 
''; + + $formatted = ''; + + // my_file.php + if (!Str\is_empty($file)) { + $formatted = Str\split($file, '/') |> $$[C\count($$) - 1]; + } + + // :123 + if ($line is nonnull) { + $formatted .= ':'.$line; + } + + // Foo::function() or my_function() + if (!Str\is_empty($function)) { + if (!Str\is_empty($class)) { + $formatted .= '#'.$class.'::'.$function.'()'; + } else { + $formatted .= '#'.$function.'()'; + } + } + + return $formatted; + } } diff --git a/src/Parser/CreateTableParser.php b/src/Parser/CreateTableParser.php index ec37e2a..7b2f947 100644 --- a/src/Parser/CreateTableParser.php +++ b/src/Parser/CreateTableParser.php @@ -228,11 +228,8 @@ private function walk(vec $tokens, string $sql, vec<(int, int)> $source_ if (C\count($temp)) { $statements[] = shape( 'tuples' => $temp, - 'sql' => Str\slice( - $sql, - $source_map[$start][0], - $source_map[$i][0] - $source_map[$start][0] + $source_map[$i][1], - ), + 'sql' => + Str\slice($sql, $source_map[$start][0], $source_map[$i][0] - $source_map[$start][0] + $source_map[$i][1]), ); } $temp = vec[]; @@ -244,11 +241,8 @@ private function walk(vec $tokens, string $sql, vec<(int, int)> $source_ if (C\count($temp)) { $statements[] = shape( 'tuples' => $temp, - 'sql' => Str\slice( - $sql, - $source_map[$start][0], - $source_map[$i][0] - $source_map[$start][0] + $source_map[$i][1], - ), + 'sql' => + Str\slice($sql, $source_map[$start][0], $source_map[$i][0] - $source_map[$start][0] + $source_map[$i][1]), ); } @@ -859,10 +853,7 @@ private function extractTokens(string $sql, vec<(int, int)> $source_map): vec CLAUSE_ORDER = dict[ - 'DELETE' => 1, - 'FROM' => 2, - 'WHERE' => 3, - 'ORDER' => 4, - 'LIMIT' => 5, - ]; + const dict CLAUSE_ORDER = dict[ + 'DELETE' => 1, + 'FROM' => 2, + 'WHERE' => 3, + 'ORDER' => 4, + 'LIMIT' => 5, + ]; - private string $currentClause = 'DELETE'; - private int $pointer = 0; + private string $currentClause = 'DELETE'; + private int $pointer = 0; - public function __construct(private token_list $tokens, 
private string $sql) {} + public function __construct(private token_list $tokens, private string $sql) {} - public function parse(): DeleteQuery { + public function parse(): DeleteQuery { - // if we got here, the first token had better be a DELETE - if ($this->tokens[$this->pointer]['value'] !== 'DELETE') { - throw new SQLFakeParseException('Parser error: expected DELETE'); - } - $this->pointer++; - $count = C\count($this->tokens); + // if we got here, the first token had better be a DELETE + if ($this->tokens[$this->pointer]['value'] !== 'DELETE') { + throw new SQLFakeParseException('Parser error: expected DELETE'); + } + $this->pointer++; + $count = C\count($this->tokens); - $query = new DeleteQuery($this->sql); + $query = new DeleteQuery($this->sql); - while ($this->pointer < $count) { - $token = $this->tokens[$this->pointer]; + while ($this->pointer < $count) { + $token = $this->tokens[$this->pointer]; - switch ($token['type']) { - case TokenType::CLAUSE: - // make sure clauses are in order - if ( - C\contains_key(self::CLAUSE_ORDER, $token['value']) && - self::CLAUSE_ORDER[$this->currentClause] >= self::CLAUSE_ORDER[$token['value']] - ) { - throw new SQLFakeParseException("Unexpected clause {$token['value']}"); - } - $this->currentClause = $token['value']; - switch ($token['value']) { - case 'FROM': - $this->pointer++; - $token = $this->tokens[$this->pointer]; - if ($token === null || $token['type'] !== TokenType::IDENTIFIER) { - throw new SQLFakeParseException('Expected table name after FROM'); - } - $table = shape( - 'name' => $token['value'], - 'join_type' => JoinType::JOIN, - ); - $query->fromClause = $table; - $this->pointer = SQLParser::skipIndexHints($this->pointer, $this->tokens); - break; - case 'WHERE': - $expression_parser = new ExpressionParser($this->tokens, $this->pointer); - list($this->pointer, $expression) = $expression_parser->buildWithPointer(); - $query->whereClause = $expression; - break; - case 'ORDER': - $p = new 
OrderByParser($this->pointer, $this->tokens); - list($this->pointer, $query->orderBy) = $p->parse(); - break; - case 'LIMIT': - $p = new LimitParser($this->pointer, $this->tokens); - list($this->pointer, $query->limitClause) = $p->parse(); - break; - default: - throw new SQLFakeParseException("Unexpected clause {$token['value']}"); - } - break; - case TokenType::RESERVED: - case TokenType::IDENTIFIER: - // just skip over these hints - if ( - $this->currentClause === 'DELETE' && - C\contains_key(keyset['LOW_PRIORITY', 'QUICK', 'IGNORE'], $token['value']) - ) { - break; - } - if ($this->currentClause === 'DELETE' && $token['type'] === TokenType::IDENTIFIER) { - // delete without FROM - $table = shape( - 'name' => $token['value'], - 'join_type' => JoinType::JOIN, - ); - $query->fromClause = $table; - $this->pointer = SQLParser::skipIndexHints($this->pointer, $this->tokens); - $this->currentClause = 'FROM'; - break; - } - throw new SQLFakeParseException("Unexpected token {$token['value']}"); - case TokenType::SEPARATOR: - // a semicolon to end the query is valid, but nothing else is in this context - if ($token['value'] !== ';') { - throw new SQLFakeParseException("Unexpected {$token['value']}"); - } - break; - default: - throw new SQLFakeParseException("Unexpected token {$token['value']}"); - } + switch ($token['type']) { + case TokenType::CLAUSE: + // make sure clauses are in order + if ( + C\contains_key(self::CLAUSE_ORDER, $token['value']) && + self::CLAUSE_ORDER[$this->currentClause] >= self::CLAUSE_ORDER[$token['value']] + ) { + throw new SQLFakeParseException("Unexpected clause {$token['value']}"); + } + $this->currentClause = $token['value']; + switch ($token['value']) { + case 'FROM': + $this->pointer++; + $token = $this->tokens[$this->pointer]; + if ($token === null || $token['type'] !== TokenType::IDENTIFIER) { + throw new SQLFakeParseException('Expected table name after FROM'); + } + $table = shape( + 'name' => $token['value'], + 'join_type' => 
JoinType::JOIN, + ); + $query->fromClause = $table; + $this->pointer = SQLParser::skipIndexHints($this->pointer, $this->tokens); + break; + case 'WHERE': + $expression_parser = new ExpressionParser($this->tokens, $this->pointer); + list($this->pointer, $expression) = $expression_parser->buildWithPointer(); + $query->whereClause = $expression; + break; + case 'ORDER': + $p = new OrderByParser($this->pointer, $this->tokens); + list($this->pointer, $query->orderBy) = $p->parse(); + break; + case 'LIMIT': + $p = new LimitParser($this->pointer, $this->tokens); + list($this->pointer, $query->limitClause) = $p->parse(); + break; + default: + throw new SQLFakeParseException("Unexpected clause {$token['value']}"); + } + break; + case TokenType::RESERVED: + case TokenType::IDENTIFIER: + // just skip over these hints + if ( + $this->currentClause === 'DELETE' && + C\contains_key(keyset['LOW_PRIORITY', 'QUICK', 'IGNORE'], $token['value']) + ) { + break; + } + if ($this->currentClause === 'DELETE' && $token['type'] === TokenType::IDENTIFIER) { + // delete without FROM + $table = shape( + 'name' => $token['value'], + 'join_type' => JoinType::JOIN, + ); + $query->fromClause = $table; + $this->pointer = SQLParser::skipIndexHints($this->pointer, $this->tokens); + $this->currentClause = 'FROM'; + break; + } + throw new SQLFakeParseException("Unexpected token {$token['value']}"); + case TokenType::SEPARATOR: + // a semicolon to end the query is valid, but nothing else is in this context + if ($token['value'] !== ';') { + throw new SQLFakeParseException("Unexpected {$token['value']}"); + } + break; + default: + throw new SQLFakeParseException("Unexpected token {$token['value']}"); + } - $this->pointer++; - } + $this->pointer++; + } - if ($query->fromClause === null) { - throw new SQLFakeParseException('Expected FROM in DELETE statement'); - } + if ($query->fromClause === null) { + throw new SQLFakeParseException('Expected FROM in DELETE statement'); + } - return $query; - } + return 
$query; + } } diff --git a/src/Parser/ExpressionParser.php b/src/Parser/ExpressionParser.php index f0461dd..0e4b715 100644 --- a/src/Parser/ExpressionParser.php +++ b/src/Parser/ExpressionParser.php @@ -17,495 +17,495 @@ */ final class ExpressionParser { - /** This represents the precedence of operators. Higher numbers = higher precedence - * UNARY_MINUS is one we actually rename because it has the same token as subtraction - * https://dev.mysql.com/doc/refman/5.7/en/operator-precedence.html - */ - const dict OPERATOR_PRECEDENCE = dict[ - 'INTERVAL' => 17, // date arithmetic - 'BINARY' => 16, // cast to a binary string - 'COLLATE' => 16, // set collation - '!' => 15, // NOT - 'UNARY_MINUS' => 14, // unary minus... to keep its name separate from - - 'UNARY_PLUS' => 14, // unary minus... to keep its name separate from - - '~' => 14, // unary bit inversion - '^' => 13, // bitwise XOR - '*' => 12, // multiplication - '/' => 12, // division - 'DIV' => 12, // integer division - '%' => 12, // modulus - 'MOD' => 12, // modulus - '-' => 11, // subtraction - '+' => 11, // addition - '<<' => 10, // left bit shift - '>>' => 10, // right bit shift - '&' => 9, // bitwise AND - '|' => 8, // bitwise OR - '=' => 7, // comparison - '<=>' => 7, // null-safe equality - '>=' => 7, // greater than or equal - '>' => 7, // greater than - '<=' => 7, // less than or equal - '<' => 7, // less than - '<>' => 7, // not equal - '!=' => 7, // not equal - 'IS' => 7, // boolean equal - 'LIKE' => 7, // string comparison - 'REGEXP' => 7, // regular expression string comparison - 'IN' => 7, // in-list comparison - 'BETWEEN' => 6, // between two values - 'CASE' => 6, // CASE statement - 'WHEN' => 6, // CASE statement - 'THEN' => 6, // CASE statement - 'ELSE' => 6, // CASE statement - 'NOT' => 5, // negation - 'AND' => 4, // boolean AND - '&&' => 4, // boolean AND - 'XOR' => 3, // boolean XOR - 'OR' => 2, // boolean OR - '||' => 2, // boolean OR - 'ASSIGNMENT' => 1, // assigning variables - ]; - - 
private ?vec $selectExpressions; - private Expression $expression; - - /** - * Most of the time, only the first two props are set. List of tokens to iterate on, and column references to index into. - * - * The rest are set when we recurse to parse an expression that has higher precedence than the current expression - * The parent passes tokens and pointer state to the child, which continues until it finds a lower precedence operator, - * Then returns its state back to the parent (buildWithPointer to move the pointer forward) - */ - public function __construct( - private token_list $tokens, - private int $pointer = -1, - ?Expression $expression = null, - public int $min_precedence = 0, - private bool $is_child = false, - ) { - if ($expression is null) { + /** This represents the precedence of operators. Higher numbers = higher precedence + * UNARY_MINUS is one we actually rename because it has the same token as subtraction + * https://dev.mysql.com/doc/refman/5.7/en/operator-precedence.html + */ + const dict OPERATOR_PRECEDENCE = dict[ + 'INTERVAL' => 17, // date arithmetic + 'BINARY' => 16, // cast to a binary string + 'COLLATE' => 16, // set collation + '!' => 15, // NOT + 'UNARY_MINUS' => 14, // unary minus... to keep its name separate from - + 'UNARY_PLUS' => 14, // unary minus... 
to keep its name separate from - + '~' => 14, // unary bit inversion + '^' => 13, // bitwise XOR + '*' => 12, // multiplication + '/' => 12, // division + 'DIV' => 12, // integer division + '%' => 12, // modulus + 'MOD' => 12, // modulus + '-' => 11, // subtraction + '+' => 11, // addition + '<<' => 10, // left bit shift + '>>' => 10, // right bit shift + '&' => 9, // bitwise AND + '|' => 8, // bitwise OR + '=' => 7, // comparison + '<=>' => 7, // null-safe equality + '>=' => 7, // greater than or equal + '>' => 7, // greater than + '<=' => 7, // less than or equal + '<' => 7, // less than + '<>' => 7, // not equal + '!=' => 7, // not equal + 'IS' => 7, // boolean equal + 'LIKE' => 7, // string comparison + 'REGEXP' => 7, // regular expression string comparison + 'IN' => 7, // in-list comparison + 'BETWEEN' => 6, // between two values + 'CASE' => 6, // CASE statement + 'WHEN' => 6, // CASE statement + 'THEN' => 6, // CASE statement + 'ELSE' => 6, // CASE statement + 'NOT' => 5, // negation + 'AND' => 4, // boolean AND + '&&' => 4, // boolean AND + 'XOR' => 3, // boolean XOR + 'OR' => 2, // boolean OR + '||' => 2, // boolean OR + 'ASSIGNMENT' => 1, // assigning variables + ]; + + private ?vec $selectExpressions; + private Expression $expression; + + /** + * Most of the time, only the first two props are set. List of tokens to iterate on, and column references to index into. 
+ * + * The rest are set when we recurse to parse an expression that has higher precedence than the current expression + * The parent passes tokens and pointer state to the child, which continues until it finds a lower precedence operator, + * Then returns its state back to the parent (buildWithPointer to move the pointer forward) + */ + public function __construct( + private token_list $tokens, + private int $pointer = -1, + ?Expression $expression = null, + public int $min_precedence = 0, + private bool $is_child = false, + ) { + if ($expression is null) { $expression = new PlaceholderExpression(); - } + } $this->expression = $expression; - } - - /** parses an expression that is inside a delimited list, such as function arguments or row expressions - * i.e.: [col1, col2, col3] - * return tuple because of a bool indicating whether the "DISTINCT" expression was found - */ - private function getListExpression(token_list $tokens): (bool, vec) { - $distinct = false; - - $pos = 0; - $token_count = C\count($tokens); - $needs_comma = false; - - $args = vec[]; - - while ($pos < $token_count) { - $arg = $tokens[$pos]; - - if ($arg['value'] === 'DISTINCT' || $arg['value'] === 'DISTINCTROW') { - $distinct = true; - $pos++; - // DISTINCT can pretend to be a function, unroll it here - if ($tokens[$pos]['type'] === TokenType::PAREN) { - $close = SQLParser::findMatchingParen($pos, $tokens); - $pos++; - $t = $tokens[$pos]; - if ($close - $pos !== 1) { - throw new SQLFakeParseException('Parse error near DISTINCT'); - } - $p = new ExpressionParser(vec[$t], -1); - $expr = $p->build(); - $args[] = $expr; - $pos += 2; - } - continue; - } - - if ($arg['value'] === ',') { - if ($needs_comma) { - $needs_comma = false; - $pos++; - continue; - } else { - throw new SQLFakeParseException('Unexpected comma in SQL query'); - } - } - $p = new ExpressionParser($tokens, $pos - 1); - list($pos, $expr) = $p->buildWithPointer(); - $args[] = $expr; - $pos++; - $needs_comma = true; - } - - return 
tuple($distinct, $args); - } - - // sometimes we identify a token is of a type that it can be immediately turned into an Expression. do that if so - public function tokenToExpression(token $token): Expression { - - switch ($token['type']) { - case TokenType::NUMERIC_CONSTANT: - case TokenType::BOOLEAN_CONSTANT: - case TokenType::STRING_CONSTANT: - case TokenType::NULL_CONSTANT: - return new ConstantExpression($token); - case TokenType::IDENTIFIER: - // if we are processing an expression in the GROUP BY or HAVING or ORDER BY, check the select list first - // this is because the select may define aliases we can use in these clauses - // i.e. SELECT something as foo ... GROUP BY foo - if ($this->selectExpressions is nonnull) { - foreach ($this->selectExpressions as $expr) { - if ($expr->name === $token['value']) { - return $expr; - } - } - } - return new ColumnExpression($token); - case TokenType::SQLFUNCTION: - // function token... just need to resolve the args - - // next token has to be a paren - $next = $this->nextToken() as nonnull; - // invariant not exception here because the parser wouldn't have seen it as a function without this - invariant($next['type'] === TokenType::PAREN, 'function is be followed by parentheses'); - $closing_paren_pointer = SQLParser::findMatchingParen($this->pointer, $this->tokens); - - // process tokens inside the function arguments - $arg_tokens = Vec\slice($this->tokens, $this->pointer + 1, $closing_paren_pointer - $this->pointer - 1); - list($distinct, $args) = $this->getListExpression($arg_tokens); - - // move the pointer forward to the end of the parentheses - $this->pointer = $closing_paren_pointer; - - $fn = Str\starts_with($token['value'], 'JSON_') - ? new JSONFunctionExpression($token, $args, $distinct) - : new FunctionExpression($token, $args, $distinct); - - return $fn; - default: - throw new SQLFakeNotImplementedException("Not implemented: {$token['value']}"); - } - } - - /** - * The main top level API of this class. 
Builds a nested, evaluatable expression - */ - public function build(): Expression { - $token = $this->nextToken(); - while ($token !== null) { - switch ($token['type']) { - case TokenType::PAREN: - - $close = SQLParser::findMatchingParen($this->pointer, $this->tokens); - $arg_tokens = Vec\slice($this->tokens, $this->pointer + 1, $close - $this->pointer - 1); - if (!C\count($arg_tokens)) { - throw new SQLFakeParseException('Empty parentheses found'); - } - $this->pointer = $close; - $expr = new PlaceholderExpression(); - - if ($arg_tokens[0]['value'] === 'SELECT') { - $subquery_sql = Vec\map($arg_tokens, $token ==> $token['value']) |> Str\join($$, ' '); - $parser = new SelectParser(0, $arg_tokens, $subquery_sql); - list($p, $select) = $parser->parse(); - $expr = new SubqueryExpression($select, ''); - } else if ($this->expression is InOperatorExpression) { - $pointer = -1; - $in_list = vec[]; - $token_count = C\count($arg_tokens); - while ($pointer < $token_count) { - $p = new ExpressionParser($arg_tokens, $pointer); - list($pointer, $expr) = $p->buildWithPointer(); - $in_list[] = $expr; - - if ($pointer + 1 >= $token_count) { - break; - } - - $pointer++; - $next = $arg_tokens[$pointer]; - if ($next['value'] !== ',') { - throw new SQLFakeParseException('Expected , in IN () list'); - } - } - $this->expression as InOperatorExpression; - $this->expression->setInList($in_list); - // we just called set_in_list so skip the setNextChild call - break; - } else { - // in this context, we could either have a sub-expression like - // (a = b AND c=d) OR a=d - // - // or we could have a LIST element, like (a, b, c) < (x, y, z) - // - - // look for a row expression like (col1, col2, col3) - $second_token = $arg_tokens[1] ?? 
null; - if ($second_token !== null && $second_token['type'] === TokenType::SEPARATOR) { - list($distinct, $elements) = $this->getListExpression($arg_tokens); - if ($distinct) { - throw new SQLFakeParseException('Unexpected DISTINCT in row expression'); - } - - $expr = new RowExpression($elements); - } else { - $p = new ExpressionParser($arg_tokens, -1); - $expr = $p->build(); - } - } - - if ($this->expression is PlaceholderExpression) { - $this->expression = new BinaryOperatorExpression($expr); - } else { - $this->expression->setNextChild($expr); - } - break; - case TokenType::NUMERIC_CONSTANT: - case TokenType::BOOLEAN_CONSTANT: - case TokenType::NULL_CONSTANT: - case TokenType::STRING_CONSTANT: - case TokenType::SQLFUNCTION: - case TokenType::IDENTIFIER: - - // these token types are all not operands and can be parsed on their own into an expression. - // do that, then let the current expression object figure out where to put it based on its state - $expr = $this->tokenToExpression($token); - - if ($this->expression is PlaceholderExpression) { - // we assume an expression will be a Binary Operator (most common) when we first encounter a token - // if it's something else like BETWEEN or IN, we convert it to that deliberately when encountering the operand - $this->expression = new BinaryOperatorExpression($expr); - } else if ( - $this->expression->operator === null && - $this->expression is BinaryOperatorExpression && - $token['type'] === TokenType::IDENTIFIER - ) { - // we encountered an identifier immediately after the left side of an expression... 
only hope is that this is an implicit alias like - // "SELECT 1 foo" - // if so, move the pointer back one and return - $this->pointer--; - return $this->expression->left; - } else { - $this->expression->setNextChild($expr); - } - break; - case TokenType::OPERATOR: - // mysql is case insensitive, but php is case sensitive so just uppercase operators for comparisons - $operator = $token['value'] as Operator; - - if ($operator === Operator::CASE) { - if (!($this->expression is PlaceholderExpression)) { - // we encountered a CASE statement inside another expression - // such as "column_name = CASE when ... end" - // so make a new instance to parse the inner CASE, we'll return back here after the END - // we move the pointer back by 1 so that the child class encounters the CASE again with a PlaceholderExpression - $this->pointer = $this->expression - ->addRecursiveExpression($this->tokens, $this->pointer - 1); - break; - } - $this->expression = new CaseOperatorExpression($token); - break; - } else if (C\contains_key(keyset['WHEN', 'THEN', 'ELSE', 'END'], $operator)) { - if (!($this->expression is CaseOperatorExpression)) { - if ($this->expression is BinaryOperatorExpression) { - // this handles "THEN 1" for example. when we hit token 1 we would have started a binary expression, - // but we just found the ELSE and we need to just return that constant expression instead - // also move the pointer back so parent can encounter this keyword - $this->pointer--; - return $this->expression->left; - } - // it's one of those things where I feel like I should have solved this before... - // otherwise we just found a keyword outside of a CASE statement, that seems wrong! 
- throw new SQLFakeParseException("Unexpected $operator"); - } - // tell the case statement we encountered a keyword so it knows where to stuff the next sub-expression - $this->expression->setKeyword(operator_to_string($operator)); - if ($operator !== 'END') { - // after WHEN, THEN, and ELSE there needs to be a well-formed expression that we have to parse - $this->pointer = $this->expression - ->addRecursiveExpression($this->tokens, $this->pointer); - } - - break; - } - - // when "EXISTS (foo)" - // TODO handle EXISTS - if ($this->expression->operator is nonnull) { - if ( - $operator === Operator::AND && - $this->expression->operator === Operator::BETWEEN && - !$this->expression->isWellFormed() - ) { - $this->expression as BetweenOperatorExpression; - $this->expression->foundAnd(); - } else if ($operator === Operator::NOT) { - if ($this->expression->operator !== Operator::IS) { - $next = $this->peekNext(); - if ( - $next !== null && - ( - ($next['type'] === TokenType::OPERATOR && Str\uppercase($next['value']) === Operator::IN) || - ($next['type'] === TokenType::OPERATOR && Str\uppercase($next['value']) === Operator::LIKE) || - $next['type'] === TokenType::PAREN - ) - ) { - // something like "A=B AND s NOT IN (foo)", we have to recurse to handle that also - // also AND NOT (foo AND bar) - $this->pointer = $this->expression - ->addRecursiveExpression($this->tokens, $this->pointer, true); - break; - } - throw new SQLFakeParseException('Unexpected NOT'); - } - $this->expression->negate(); - } else { - // If the new operator has higher precedence, we need to recurse so that it can end up inside the current expression (so it gets evaluated first) - // Otherwise, we take the entire current expression and nest it inside a new one, which we assume to be Binary for now - - $current_op_precedence = $this->expression->precedence; - $new_op_precedence = $this->getPrecedence(operator_to_string($operator)); - if ($current_op_precedence < $new_op_precedence) { - // example: 5 + 8 
* 3 - // we are at the "*" right now, and have to move "8" out of the "right" from - // the + operator and into the "left" of the new "*' operator, which gets nested inside the + as its own "right" - $this->pointer = $this->expression - ->addRecursiveExpression($this->tokens, $this->pointer - 1); - } else { - // example: 9 / 3 * 3 - // We are at the *. Take the entire current expression, make it be the "left" of a new expression with a "type" of the current operator - // It's important to nest like this to preserve left-to-right evaluation. - if ($operator === Operator::BETWEEN) { - $this->expression = new BetweenOperatorExpression($this->expression); - } else if ($operator === Operator::IN) { - $this->expression = new InOperatorExpression($this->expression, $this->expression->negated); - } else { - $this->expression = new BinaryOperatorExpression($this->expression, false, $operator); - } - } - } - } else { - if ($operator === Operator::BETWEEN) { - if (!$this->expression is BinaryOperatorExpression) { - throw new SQLFakeParseException('Unexpected keyword BETWEEN'); - } - $this->expression = new BetweenOperatorExpression($this->expression->left); - } else if ($operator === Operator::NOT) { - // this negates another operator like "NOT IN" or "IS NOT NULL" - // operators would throw an SQLFakeException here if they don't support negation - $this->expression->negate(); - } else if ($operator === Operator::IN) { - if (!$this->expression is BinaryOperatorExpression) { - throw new SQLFakeParseException('Unexpected keyword IN'); - } - $this->expression = new InOperatorExpression($this->expression->left, $this->expression->negated); - } else if ( - $operator === Operator::UNARY_MINUS || $operator === Operator::UNARY_PLUS || $operator === Operator::TILDE - ) { - $this->expression as PlaceholderExpression; - $this->expression = new UnaryExpression($operator); - } else { - $this->expression as BinaryOperatorExpression; - $this->expression->setOperator($operator); - } - } - - 
break; - default: - throw new SQLFakeParseException("Expression parse error: unexpected {$token['value']}"); - } - - // don't move pointer forward and break early for some keywords - $nextToken = $this->peekNext(); - - if (!$nextToken) { - break; - } - - // return control to parent when we hit one of these - // special case: VALUES is both a clause and can also be a function inside an INSERT... ON DUPLICATE KEY UPDATE - // if we find a VALUES and the current expression is incomplete, we keep going - // TODO actually maybe we should just seek for the VALUES( sequence in the setParser and collapse those into a sqlFunction token?? - if (C\contains_key(keyset[TokenType::CLAUSE, TokenType::RESERVED, TokenType::SEPARATOR], $nextToken['type'])) { - if ($nextToken['value'] === 'VALUES' && !$this->expression->isWellFormed()) { - // change VALUES from a CLAUSE to SQLFUNCTION if it occurs inside an expression - $this->tokens[$this->pointer + 1]['type'] = TokenType::SQLFUNCTION; - } else { - break; - } - } - - // possibly break out of the loop depending on next token, operator precedence, child status - if ($this->expression->isWellFormed()) { - // alias for the expression? - if ($nextToken['type'] === TokenType::IDENTIFIER) { - break; - } - - // this happens when processing a sub-expression inside of a CASE statement - // when we encounter the next CASE keyword, and already have a well formed expression, break and let the parent handle it - if (C\contains_key(keyset['ELSE', 'THEN', 'END'], $nextToken['value'])) { - break; - } - - // the only other valid thing to come after a well_formed operation is another operand (other than the things we break on just above) - if ($nextToken['type'] !== TokenType::OPERATOR) { - throw new SQLFakeParseException("Unexpected token {$nextToken['value']}"); - } - - if ($this->is_child) { - $next_operator_precedence = $this->getPrecedence($nextToken['value']); - - // we are inside a recursive child and found a lower or same precedence operator?? 
- // then bail, the parent needs to take it from here - // what matters here is not the current operator's precedence, but the lowest precedence we have seen in this instance - // (from Precedence Climbing algorithm) - if ($next_operator_precedence <= $this->min_precedence) { - break; - } - } - } - - $token = $this->nextToken(); - } - - if (!$this->expression->isWellFormed()) { - // if we encountered some token like a column, constant, or subquery and we didn't find any more tokens than that, just return that token as the entire expression - if ($this->expression is BinaryOperatorExpression && $this->expression->operator === null) { - return $this->expression->left; - } - throw new SQLFakeParseException('Parse error, unexpected end of input'); - } - return $this->expression; - } - - public function buildWithPointer(): (int, Expression) { - $expr = $this->build(); - return tuple($this->pointer, $expr); - } - - private function nextToken(): ?token { - $this->pointer++; - return $this->tokens[$this->pointer] ?? null; - } - - private function peekNext(): ?token { - return $this->tokens[$this->pointer + 1] ?? null; - } - - private function getPrecedence(string $operator): int { - return self::OPERATOR_PRECEDENCE[$operator] ?? 
0; - } - - /* - * When parsing expressions in certain places like the GROUP BY or HAVING clauses, it's possible for column references to refer to aliases defined in the SELECT list - * This function can be called before parsing expressions in those places, so that if a column reference is not found in the tables it can be found from the select list - */ - public function setSelectExpressions(vec $expressions): void { - $this->selectExpressions = $expressions; - } + } + + /** parses an expression that is inside a delimited list, such as function arguments or row expressions + * i.e.: [col1, col2, col3] + * return tuple because of a bool indicating whether the "DISTINCT" expression was found + */ + private function getListExpression(token_list $tokens): (bool, vec) { + $distinct = false; + + $pos = 0; + $token_count = C\count($tokens); + $needs_comma = false; + + $args = vec[]; + + while ($pos < $token_count) { + $arg = $tokens[$pos]; + + if ($arg['value'] === 'DISTINCT' || $arg['value'] === 'DISTINCTROW') { + $distinct = true; + $pos++; + // DISTINCT can pretend to be a function, unroll it here + if ($tokens[$pos]['type'] === TokenType::PAREN) { + $close = SQLParser::findMatchingParen($pos, $tokens); + $pos++; + $t = $tokens[$pos]; + if ($close - $pos !== 1) { + throw new SQLFakeParseException('Parse error near DISTINCT'); + } + $p = new ExpressionParser(vec[$t], -1); + $expr = $p->build(); + $args[] = $expr; + $pos += 2; + } + continue; + } + + if ($arg['value'] === ',') { + if ($needs_comma) { + $needs_comma = false; + $pos++; + continue; + } else { + throw new SQLFakeParseException('Unexpected comma in SQL query'); + } + } + $p = new ExpressionParser($tokens, $pos - 1); + list($pos, $expr) = $p->buildWithPointer(); + $args[] = $expr; + $pos++; + $needs_comma = true; + } + + return tuple($distinct, $args); + } + + // sometimes we identify a token is of a type that it can be immediately turned into an Expression. 
do that if so + public function tokenToExpression(token $token): Expression { + + switch ($token['type']) { + case TokenType::NUMERIC_CONSTANT: + case TokenType::BOOLEAN_CONSTANT: + case TokenType::STRING_CONSTANT: + case TokenType::NULL_CONSTANT: + return new ConstantExpression($token); + case TokenType::IDENTIFIER: + // if we are processing an expression in the GROUP BY or HAVING or ORDER BY, check the select list first + // this is because the select may define aliases we can use in these clauses + // i.e. SELECT something as foo ... GROUP BY foo + if ($this->selectExpressions is nonnull) { + foreach ($this->selectExpressions as $expr) { + if ($expr->name === $token['value']) { + return $expr; + } + } + } + return new ColumnExpression($token); + case TokenType::SQLFUNCTION: + // function token... just need to resolve the args + + // next token has to be a paren + $next = $this->nextToken() as nonnull; + // invariant not exception here because the parser wouldn't have seen it as a function without this + invariant($next['type'] === TokenType::PAREN, 'function is be followed by parentheses'); + $closing_paren_pointer = SQLParser::findMatchingParen($this->pointer, $this->tokens); + + // process tokens inside the function arguments + $arg_tokens = Vec\slice($this->tokens, $this->pointer + 1, $closing_paren_pointer - $this->pointer - 1); + list($distinct, $args) = $this->getListExpression($arg_tokens); + + // move the pointer forward to the end of the parentheses + $this->pointer = $closing_paren_pointer; + + $fn = Str\starts_with($token['value'], 'JSON_') + ? new JSONFunctionExpression($token, $args, $distinct) + : new FunctionExpression($token, $args, $distinct); + + return $fn; + default: + throw new SQLFakeNotImplementedException("Not implemented: {$token['value']}"); + } + } + + /** + * The main top level API of this class. 
Builds a nested, evaluatable expression + */ + public function build(): Expression { + $token = $this->nextToken(); + while ($token !== null) { + switch ($token['type']) { + case TokenType::PAREN: + + $close = SQLParser::findMatchingParen($this->pointer, $this->tokens); + $arg_tokens = Vec\slice($this->tokens, $this->pointer + 1, $close - $this->pointer - 1); + if (!C\count($arg_tokens)) { + throw new SQLFakeParseException('Empty parentheses found'); + } + $this->pointer = $close; + $expr = new PlaceholderExpression(); + + if ($arg_tokens[0]['value'] === 'SELECT') { + $subquery_sql = Vec\map($arg_tokens, $token ==> $token['value']) |> Str\join($$, ' '); + $parser = new SelectParser(0, $arg_tokens, $subquery_sql); + list($p, $select) = $parser->parse(); + $expr = new SubqueryExpression($select, ''); + } else if ($this->expression is InOperatorExpression) { + $pointer = -1; + $in_list = vec[]; + $token_count = C\count($arg_tokens); + while ($pointer < $token_count) { + $p = new ExpressionParser($arg_tokens, $pointer); + list($pointer, $expr) = $p->buildWithPointer(); + $in_list[] = $expr; + + if ($pointer + 1 >= $token_count) { + break; + } + + $pointer++; + $next = $arg_tokens[$pointer]; + if ($next['value'] !== ',') { + throw new SQLFakeParseException('Expected , in IN () list'); + } + } + $this->expression as InOperatorExpression; + $this->expression->setInList($in_list); + // we just called set_in_list so skip the setNextChild call + break; + } else { + // in this context, we could either have a sub-expression like + // (a = b AND c=d) OR a=d + // + // or we could have a LIST element, like (a, b, c) < (x, y, z) + // + + // look for a row expression like (col1, col2, col3) + $second_token = $arg_tokens[1] ?? 
null; + if ($second_token !== null && $second_token['type'] === TokenType::SEPARATOR) { + list($distinct, $elements) = $this->getListExpression($arg_tokens); + if ($distinct) { + throw new SQLFakeParseException('Unexpected DISTINCT in row expression'); + } + + $expr = new RowExpression($elements); + } else { + $p = new ExpressionParser($arg_tokens, -1); + $expr = $p->build(); + } + } + + if ($this->expression is PlaceholderExpression) { + $this->expression = new BinaryOperatorExpression($expr); + } else { + $this->expression->setNextChild($expr); + } + break; + case TokenType::NUMERIC_CONSTANT: + case TokenType::BOOLEAN_CONSTANT: + case TokenType::NULL_CONSTANT: + case TokenType::STRING_CONSTANT: + case TokenType::SQLFUNCTION: + case TokenType::IDENTIFIER: + + // these token types are all not operands and can be parsed on their own into an expression. + // do that, then let the current expression object figure out where to put it based on its state + $expr = $this->tokenToExpression($token); + + if ($this->expression is PlaceholderExpression) { + // we assume an expression will be a Binary Operator (most common) when we first encounter a token + // if it's something else like BETWEEN or IN, we convert it to that deliberately when encountering the operand + $this->expression = new BinaryOperatorExpression($expr); + } else if ( + $this->expression->operator === null && + $this->expression is BinaryOperatorExpression && + $token['type'] === TokenType::IDENTIFIER + ) { + // we encountered an identifier immediately after the left side of an expression... 
only hope is that this is an implicit alias like + // "SELECT 1 foo" + // if so, move the pointer back one and return + $this->pointer--; + return $this->expression->left; + } else { + $this->expression->setNextChild($expr); + } + break; + case TokenType::OPERATOR: + // mysql is case insensitive, but php is case sensitive so just uppercase operators for comparisons + $operator = $token['value'] as Operator; + + if ($operator === Operator::CASE) { + if (!($this->expression is PlaceholderExpression)) { + // we encountered a CASE statement inside another expression + // such as "column_name = CASE when ... end" + // so make a new instance to parse the inner CASE, we'll return back here after the END + // we move the pointer back by 1 so that the child class encounters the CASE again with a PlaceholderExpression + $this->pointer = $this->expression + ->addRecursiveExpression($this->tokens, $this->pointer - 1); + break; + } + $this->expression = new CaseOperatorExpression($token); + break; + } else if (C\contains_key(keyset['WHEN', 'THEN', 'ELSE', 'END'], $operator)) { + if (!($this->expression is CaseOperatorExpression)) { + if ($this->expression is BinaryOperatorExpression) { + // this handles "THEN 1" for example. when we hit token 1 we would have started a binary expression, + // but we just found the ELSE and we need to just return that constant expression instead + // also move the pointer back so parent can encounter this keyword + $this->pointer--; + return $this->expression->left; + } + // it's one of those things where I feel like I should have solved this before... + // otherwise we just found a keyword outside of a CASE statement, that seems wrong! 
+ throw new SQLFakeParseException("Unexpected $operator"); + } + // tell the case statement we encountered a keyword so it knows where to stuff the next sub-expression + $this->expression->setKeyword(operator_to_string($operator)); + if ($operator !== 'END') { + // after WHEN, THEN, and ELSE there needs to be a well-formed expression that we have to parse + $this->pointer = $this->expression + ->addRecursiveExpression($this->tokens, $this->pointer); + } + + break; + } + + // when "EXISTS (foo)" + // TODO handle EXISTS + if ($this->expression->operator is nonnull) { + if ( + $operator === Operator::AND && + $this->expression->operator === Operator::BETWEEN && + !$this->expression->isWellFormed() + ) { + $this->expression as BetweenOperatorExpression; + $this->expression->foundAnd(); + } else if ($operator === Operator::NOT) { + if ($this->expression->operator !== Operator::IS) { + $next = $this->peekNext(); + if ( + $next !== null && + ( + ($next['type'] === TokenType::OPERATOR && Str\uppercase($next['value']) === Operator::IN) || + ($next['type'] === TokenType::OPERATOR && Str\uppercase($next['value']) === Operator::LIKE) || + $next['type'] === TokenType::PAREN + ) + ) { + // something like "A=B AND s NOT IN (foo)", we have to recurse to handle that also + // also AND NOT (foo AND bar) + $this->pointer = $this->expression + ->addRecursiveExpression($this->tokens, $this->pointer, true); + break; + } + throw new SQLFakeParseException('Unexpected NOT'); + } + $this->expression->negate(); + } else { + // If the new operator has higher precedence, we need to recurse so that it can end up inside the current expression (so it gets evaluated first) + // Otherwise, we take the entire current expression and nest it inside a new one, which we assume to be Binary for now + + $current_op_precedence = $this->expression->precedence; + $new_op_precedence = $this->getPrecedence(operator_to_string($operator)); + if ($current_op_precedence < $new_op_precedence) { + // example: 5 + 8 
* 3 + // we are at the "*" right now, and have to move "8" out of the "right" from + // the + operator and into the "left" of the new "*' operator, which gets nested inside the + as its own "right" + $this->pointer = $this->expression + ->addRecursiveExpression($this->tokens, $this->pointer - 1); + } else { + // example: 9 / 3 * 3 + // We are at the *. Take the entire current expression, make it be the "left" of a new expression with a "type" of the current operator + // It's important to nest like this to preserve left-to-right evaluation. + if ($operator === Operator::BETWEEN) { + $this->expression = new BetweenOperatorExpression($this->expression); + } else if ($operator === Operator::IN) { + $this->expression = new InOperatorExpression($this->expression, $this->expression->negated); + } else { + $this->expression = new BinaryOperatorExpression($this->expression, false, $operator); + } + } + } + } else { + if ($operator === Operator::BETWEEN) { + if (!$this->expression is BinaryOperatorExpression) { + throw new SQLFakeParseException('Unexpected keyword BETWEEN'); + } + $this->expression = new BetweenOperatorExpression($this->expression->left); + } else if ($operator === Operator::NOT) { + // this negates another operator like "NOT IN" or "IS NOT NULL" + // operators would throw an SQLFakeException here if they don't support negation + $this->expression->negate(); + } else if ($operator === Operator::IN) { + if (!$this->expression is BinaryOperatorExpression) { + throw new SQLFakeParseException('Unexpected keyword IN'); + } + $this->expression = new InOperatorExpression($this->expression->left, $this->expression->negated); + } else if ( + $operator === Operator::UNARY_MINUS || $operator === Operator::UNARY_PLUS || $operator === Operator::TILDE + ) { + $this->expression as PlaceholderExpression; + $this->expression = new UnaryExpression($operator); + } else { + $this->expression as BinaryOperatorExpression; + $this->expression->setOperator($operator); + } + } + + 
break; + default: + throw new SQLFakeParseException("Expression parse error: unexpected {$token['value']}"); + } + + // don't move pointer forward and break early for some keywords + $nextToken = $this->peekNext(); + + if (!$nextToken) { + break; + } + + // return control to parent when we hit one of these + // special case: VALUES is both a clause and can also be a function inside an INSERT... ON DUPLICATE KEY UPDATE + // if we find a VALUES and the current expression is incomplete, we keep going + // TODO actually maybe we should just seek for the VALUES( sequence in the setParser and collapse those into a sqlFunction token?? + if (C\contains_key(keyset[TokenType::CLAUSE, TokenType::RESERVED, TokenType::SEPARATOR], $nextToken['type'])) { + if ($nextToken['value'] === 'VALUES' && !$this->expression->isWellFormed()) { + // change VALUES from a CLAUSE to SQLFUNCTION if it occurs inside an expression + $this->tokens[$this->pointer + 1]['type'] = TokenType::SQLFUNCTION; + } else { + break; + } + } + + // possibly break out of the loop depending on next token, operator precedence, child status + if ($this->expression->isWellFormed()) { + // alias for the expression? + if ($nextToken['type'] === TokenType::IDENTIFIER) { + break; + } + + // this happens when processing a sub-expression inside of a CASE statement + // when we encounter the next CASE keyword, and already have a well formed expression, break and let the parent handle it + if (C\contains_key(keyset['ELSE', 'THEN', 'END'], $nextToken['value'])) { + break; + } + + // the only other valid thing to come after a well_formed operation is another operand (other than the things we break on just above) + if ($nextToken['type'] !== TokenType::OPERATOR) { + throw new SQLFakeParseException("Unexpected token {$nextToken['value']}"); + } + + if ($this->is_child) { + $next_operator_precedence = $this->getPrecedence($nextToken['value']); + + // we are inside a recursive child and found a lower or same precedence operator?? 
+ // then bail, the parent needs to take it from here + // what matters here is not the current operator's precedence, but the lowest precedence we have seen in this instance + // (from Precedence Climbing algorithm) + if ($next_operator_precedence <= $this->min_precedence) { + break; + } + } + } + + $token = $this->nextToken(); + } + + if (!$this->expression->isWellFormed()) { + // if we encountered some token like a column, constant, or subquery and we didn't find any more tokens than that, just return that token as the entire expression + if ($this->expression is BinaryOperatorExpression && $this->expression->operator === null) { + return $this->expression->left; + } + throw new SQLFakeParseException('Parse error, unexpected end of input'); + } + return $this->expression; + } + + public function buildWithPointer(): (int, Expression) { + $expr = $this->build(); + return tuple($this->pointer, $expr); + } + + private function nextToken(): ?token { + $this->pointer++; + return $this->tokens[$this->pointer] ?? null; + } + + private function peekNext(): ?token { + return $this->tokens[$this->pointer + 1] ?? null; + } + + private function getPrecedence(string $operator): int { + return self::OPERATOR_PRECEDENCE[$operator] ?? 
0; + } + + /* + * When parsing expressions in certain places like the GROUP BY or HAVING clauses, it's possible for column references to refer to aliases defined in the SELECT list + * This function can be called before parsing expressions in those places, so that if a column reference is not found in the tables it can be found from the select list + */ + public function setSelectExpressions(vec $expressions): void { + $this->selectExpressions = $expressions; + } } diff --git a/src/Parser/FromParser.php b/src/Parser/FromParser.php index 0a12532..b66e6cd 100644 --- a/src/Parser/FromParser.php +++ b/src/Parser/FromParser.php @@ -6,323 +6,323 @@ final class FromParser { - public function __construct(private int $pointer, private token_list $tokens) {} - - public function parse(): (int, FromClause) { - - // if we got here, the first token had better be a SELECT - if ($this->tokens[$this->pointer]['value'] !== 'FROM') { - throw new SQLFakeParseException('Parser error: expected FROM'); - } - $from = new FromClause(); - $this->pointer++; - $count = C\count($this->tokens); - - while ($this->pointer < $count) { - $token = $this->tokens[$this->pointer]; - - switch ($token['type']) { - case TokenType::STRING_CONSTANT: - if (!$from->mostRecentHasAlias) { - $from->aliasRecentExpression((string)$token['value']); - $this->pointer = SQLParser::skipIndexHints($this->pointer, $this->tokens); - } else { - throw new SQLFakeParseException("Unexpected string constant {$token['raw']}"); - } - break; - case TokenType::IDENTIFIER: - if (C\count($from->tables) === 0) { - $table = shape( - 'name' => $token['value'], - 'join_type' => JoinType::JOIN, - ); - $from->addTable($table); - $this->pointer = SQLParser::skipIndexHints($this->pointer, $this->tokens); - } else if (!$from->mostRecentHasAlias) { - $from->aliasRecentExpression((string)$token['value']); - $this->pointer = SQLParser::skipIndexHints($this->pointer, $this->tokens); - } - break; - case TokenType::SEPARATOR: - if ($token['value'] 
=== ',') { - $this->pointer++; - $next = $this->tokens[$this->pointer] ?? null; - if ($next === null) { - throw new SQLFakeParseException('Expected token after ,'); - } - $table = $this->getTableOrSubquery($next); - $table['join_type'] = JoinType::CROSS; - $from->addTable($table); - $this->pointer = SQLParser::skipIndexHints($this->pointer, $this->tokens); - } else { - throw new SQLFakeParseException("Unexpected {$token['value']}"); - } - break; - case TokenType::CLAUSE: - // we made it to the next clause, time to return - // let parent process this keyword - return tuple($this->pointer - 1, $from); - case TokenType::RESERVED: - switch ($token['value']) { - case 'AS': - // seek forward for identifier, then add alias to most recent expression - $this->pointer++; - $next = $this->tokens[$this->pointer] ?? null; - if ($next === null || $next['type'] !== TokenType::IDENTIFIER) { - throw new SQLFakeParseException('Expected identifer after AS'); - } - $from->aliasRecentExpression($next['value']); - $this->pointer = SQLParser::skipIndexHints($this->pointer, $this->tokens); - break; - case 'JOIN': - case 'INNER': - case 'LEFT': - case 'RIGHT': - case 'STRAIGHT_JOIN': - case 'NATURAL': - case 'CROSS': - $last = C\last($from->tables); - if ($last === null) { - throw new SQLFakeParseException('Parser error: unexpected join keyword'); - } - $join = $this->buildJoin($last['name'], $token); - $from->addTable($join); - break; - default: - throw new SQLFakeParseException("Unexpected {$token['value']}"); - } - break; - case TokenType::PAREN: - // this must be a subquery in the FROM clause - $subquery = $this->getTableOrSubquery($token); - $from->addTable($subquery); - break; - default: - throw new SQLFakeParseException("Unexpected {$token['value']}"); - } - - $this->pointer++; - } - - return tuple($this->pointer, $from); - } - - private function getTableOrSubquery(token $token): from_table { - switch ($token['type']) { - case TokenType::IDENTIFIER: - return shape('name' => 
$token['value'], 'join_type' => JoinType::JOIN); - case TokenType::PAREN: - // this must be a subquery in the FROM clause - $close = SQLParser::findMatchingParen($this->pointer, $this->tokens); - $subquery_tokens = Vec\slice($this->tokens, $this->pointer + 1, $close - $this->pointer - 1); - if (!C\count($subquery_tokens)) { - throw new SQLFakeParseException('Empty parentheses found'); - } - $expr = new PlaceholderExpression(); - - // this will throw if the first keyword isn't SELECT which is what we want - $subquery_sql = Vec\map($subquery_tokens, $token ==> $token['value']) |> Str\join($$, ' '); - $parser = new SelectParser(0, $subquery_tokens, $subquery_sql); - list($p, $select) = $parser->parse(); - $expr = new SubqueryExpression($select, ''); - // only move pointer forward by $p - $this->pointer += $p + 1; - $next = $this->tokens[$this->pointer] ?? null; - - // we still have something left here after parsing the whole top level query? hopefully it's a multi-query keyword - while ($next !== null && C\contains_key(keyset['UNION', 'INTERSECT', 'EXCEPT'], $next['value'])) { - $type = $next['value']; - if ($next['value'] === 'UNION') { - $next_plus = $this->tokens[$this->pointer + 1]; - if ($next_plus['value'] === 'ALL') { - $type = 'UNION_ALL'; - $this->pointer++; - } - if ($next_plus['value'] === 'DISTINCT') { - $this->pointer++; - } - } - $this->pointer++; - $next = $this->tokens[$this->pointer] ?? null; - $subselect = new SelectParser($this->pointer, $this->tokens, ''); - list($p, $q) = $subselect->parse(); - $this->pointer += $p; - $select->addMultiQuery(MultiOperand::assert($type), $q); - - $next = $this->tokens[$this->pointer] ?? null; - } - - $this->pointer = $close + 1; - $next = $this->tokens[$this->pointer] ?? 
null; - - if ($next !== null && $next['value'] === 'AS') { - $this->pointer++; - $next = $this->tokens[$this->pointer]; - } - if ($next === null || $next['type'] !== TokenType::IDENTIFIER) { - throw new SQLFakeParseException('Every subquery must have an alias'); - } - $name = $next['value']; - - $table = shape( - 'name' => $name, - 'subquery' => $expr, - 'join_type' => JoinType::JOIN, - 'alias' => $name, - ); - return $table; - default: - throw new SQLFakeParseException('Expected table name or subquery'); - } - - } - - /* - * Seek as far forward as is needed to build the JOIN expression including join type, conditions, table name and alias - */ - private function buildJoin(string $left_table, token $token): from_table { - - // INNER JOIN and JOIN are aliases - $join_type = $token['value']; - if ($token['value'] === 'INNER') { - $join_type = 'JOIN'; - } - - if (C\contains_key(keyset['INNER', 'CROSS', 'NATURAL'], $token['value'])) { - $this->pointer++; - $next = $this->tokens[$this->pointer] ?? null; - if ($next === null || $next['value'] !== 'JOIN') { - throw new SQLFakeParseException("Expected keyword JOIN after {$token['value']}"); - } - } else if (C\contains_key(keyset['LEFT', 'RIGHT'], $token['value'])) { - $this->pointer++; - $next = $this->tokens[$this->pointer] ?? null; - if ($next !== null && $next['value'] === 'OUTER') { - $this->pointer++; - $next = $this->tokens[$this->pointer] ?? null; - } - if ($next === null || $next['value'] !== 'JOIN') { - throw new SQLFakeParseException("Expected keyword JOIN after {$token['value']}"); - } - } - - $this->pointer++; - $next = $this->tokens[$this->pointer] ?? null; - if ($next === null) { - throw new SQLFakeParseException('Expected table or subquery after join keyword'); - } - $table = $this->getTableOrSubquery($next); - $table['join_type'] = JoinType::assert($join_type); - - $this->pointer++; - $next = $this->tokens[$this->pointer] ?? 
null; - // it is possible to end the entire query here like "SELECT * from foo join bar" - if ($next === null) { - return $table; - } - - // maybe add alias - if ($next['type'] === TokenType::IDENTIFIER) { - $table['alias'] = $next['value']; - $this->pointer++; - $next = $this->tokens[$this->pointer] ?? null; - if ($next === null) { - return $table; - } - } else if ($next['value'] === 'AS') { - $this->pointer++; - $next = $this->tokens[$this->pointer] ?? null; - if ($next === null || $next['type'] !== TokenType::IDENTIFIER) { - throw new SQLFakeParseException('Expected identifier after AS'); - } - $table['alias'] = $next['value']; - $this->pointer++; - $next = $this->tokens[$this->pointer] ?? null; - if ($next === null) { - return $table; - } - } - - $this->pointer = SQLParser::skipIndexHints($this->pointer, $this->tokens); - - // this has to end here if it's a natural or cross join - // NOTE the natural join filter has to be built later, at runtime - if (C\contains_key(keyset[JoinType::NATURAL, JoinType::CROSS], $table['join_type'])) { - return $table; - } - - // now we need ON or USING - if ($next['type'] !== TokenType::RESERVED || !C\contains_key(keyset['ON', 'USING'], $next['value'])) { - throw new SQLFakeParseException('Expected ON or USING join condition'); - } - - if ($next['value'] === 'USING') { - $table['join_operator'] = JoinOperator::USING; - $this->pointer++; - $next = $this->tokens[$this->pointer] ?? 
null; - if ($next === null || $next['type'] !== TokenType::PAREN) { - throw new SQLFakeParseException('Expected ( after USING clause'); - } - $closing_paren_pointer = SQLParser::findMatchingParen($this->pointer, $this->tokens); - $arg_tokens = Vec\slice($this->tokens, $this->pointer + 1, $closing_paren_pointer - $this->pointer - 1); - if (!C\count($arg_tokens)) { - throw new SQLFakeParseException('Expected at least one argument to USING() clause'); - } - $count = 0; - $filter = null; - foreach ($arg_tokens as $arg) { - $count++; - if ($count % 2 === 1) { - // odd arguments should be columns - if ($arg['type'] !== TokenType::IDENTIFIER) { - throw new SQLFakeParseException('Expected identifier in USING clause'); - } - $filter = $this->addJoinFilterExpression($filter, $left_table, $table['name'], $arg['value']); - } else if ($arg['value'] !== ',') { - throw new SQLFakeParseException('Expected , after argument in USING clause'); - } - } - - $this->pointer = $closing_paren_pointer + 1; - $table['join_expression'] = $filter; - } else { - $table['join_operator'] = JoinOperator::ON; - $expression_parser = new ExpressionParser($this->tokens, $this->pointer); - list($this->pointer, $expression) = $expression_parser->buildWithPointer(); - $table['join_expression'] = $expression; - } - - return $table; - } - - public function addJoinFilterExpression( - ?Expression $filter, - string $left_table, - string $right_table, - string $column, - ): BinaryOperatorExpression { - - $left = new ColumnExpression(shape( - 'type' => TokenType::IDENTIFIER, - 'value' => "{$left_table}.{$column}", - 'raw' => '', - )); - $right = new ColumnExpression(shape( - 'type' => TokenType::IDENTIFIER, - 'value' => "{$right_table}.{$column}", - 'raw' => '', - )); - - // making a binary expression ensuring those two tokens are equal - $expr = new BinaryOperatorExpression($left, /* $negated */ false, Operator::EQUALS, $right); - - // if this is not the first condition, make an AND that wraps the current and 
new filter - if ($filter !== null) { - $filter = new BinaryOperatorExpression($filter, /* $negated */ false, Operator::AND, $expr); - } else { - $filter = $expr; - } - - return $filter; - } + public function __construct(private int $pointer, private token_list $tokens) {} + + public function parse(): (int, FromClause) { + + // if we got here, the first token had better be a SELECT + if ($this->tokens[$this->pointer]['value'] !== 'FROM') { + throw new SQLFakeParseException('Parser error: expected FROM'); + } + $from = new FromClause(); + $this->pointer++; + $count = C\count($this->tokens); + + while ($this->pointer < $count) { + $token = $this->tokens[$this->pointer]; + + switch ($token['type']) { + case TokenType::STRING_CONSTANT: + if (!$from->mostRecentHasAlias) { + $from->aliasRecentExpression((string)$token['value']); + $this->pointer = SQLParser::skipIndexHints($this->pointer, $this->tokens); + } else { + throw new SQLFakeParseException("Unexpected string constant {$token['raw']}"); + } + break; + case TokenType::IDENTIFIER: + if (C\count($from->tables) === 0) { + $table = shape( + 'name' => $token['value'], + 'join_type' => JoinType::JOIN, + ); + $from->addTable($table); + $this->pointer = SQLParser::skipIndexHints($this->pointer, $this->tokens); + } else if (!$from->mostRecentHasAlias) { + $from->aliasRecentExpression((string)$token['value']); + $this->pointer = SQLParser::skipIndexHints($this->pointer, $this->tokens); + } + break; + case TokenType::SEPARATOR: + if ($token['value'] === ',') { + $this->pointer++; + $next = $this->tokens[$this->pointer] ?? 
null; + if ($next === null) { + throw new SQLFakeParseException('Expected token after ,'); + } + $table = $this->getTableOrSubquery($next); + $table['join_type'] = JoinType::CROSS; + $from->addTable($table); + $this->pointer = SQLParser::skipIndexHints($this->pointer, $this->tokens); + } else { + throw new SQLFakeParseException("Unexpected {$token['value']}"); + } + break; + case TokenType::CLAUSE: + // we made it to the next clause, time to return + // let parent process this keyword + return tuple($this->pointer - 1, $from); + case TokenType::RESERVED: + switch ($token['value']) { + case 'AS': + // seek forward for identifier, then add alias to most recent expression + $this->pointer++; + $next = $this->tokens[$this->pointer] ?? null; + if ($next === null || $next['type'] !== TokenType::IDENTIFIER) { + throw new SQLFakeParseException('Expected identifer after AS'); + } + $from->aliasRecentExpression($next['value']); + $this->pointer = SQLParser::skipIndexHints($this->pointer, $this->tokens); + break; + case 'JOIN': + case 'INNER': + case 'LEFT': + case 'RIGHT': + case 'STRAIGHT_JOIN': + case 'NATURAL': + case 'CROSS': + $last = C\last($from->tables); + if ($last === null) { + throw new SQLFakeParseException('Parser error: unexpected join keyword'); + } + $join = $this->buildJoin($last['name'], $token); + $from->addTable($join); + break; + default: + throw new SQLFakeParseException("Unexpected {$token['value']}"); + } + break; + case TokenType::PAREN: + // this must be a subquery in the FROM clause + $subquery = $this->getTableOrSubquery($token); + $from->addTable($subquery); + break; + default: + throw new SQLFakeParseException("Unexpected {$token['value']}"); + } + + $this->pointer++; + } + + return tuple($this->pointer, $from); + } + + private function getTableOrSubquery(token $token): from_table { + switch ($token['type']) { + case TokenType::IDENTIFIER: + return shape('name' => $token['value'], 'join_type' => JoinType::JOIN); + case TokenType::PAREN: + // 
this must be a subquery in the FROM clause + $close = SQLParser::findMatchingParen($this->pointer, $this->tokens); + $subquery_tokens = Vec\slice($this->tokens, $this->pointer + 1, $close - $this->pointer - 1); + if (!C\count($subquery_tokens)) { + throw new SQLFakeParseException('Empty parentheses found'); + } + $expr = new PlaceholderExpression(); + + // this will throw if the first keyword isn't SELECT which is what we want + $subquery_sql = Vec\map($subquery_tokens, $token ==> $token['value']) |> Str\join($$, ' '); + $parser = new SelectParser(0, $subquery_tokens, $subquery_sql); + list($p, $select) = $parser->parse(); + $expr = new SubqueryExpression($select, ''); + // only move pointer forward by $p + $this->pointer += $p + 1; + $next = $this->tokens[$this->pointer] ?? null; + + // we still have something left here after parsing the whole top level query? hopefully it's a multi-query keyword + while ($next !== null && C\contains_key(keyset['UNION', 'INTERSECT', 'EXCEPT'], $next['value'])) { + $type = $next['value']; + if ($next['value'] === 'UNION') { + $next_plus = $this->tokens[$this->pointer + 1]; + if ($next_plus['value'] === 'ALL') { + $type = 'UNION_ALL'; + $this->pointer++; + } + if ($next_plus['value'] === 'DISTINCT') { + $this->pointer++; + } + } + $this->pointer++; + $next = $this->tokens[$this->pointer] ?? null; + $subselect = new SelectParser($this->pointer, $this->tokens, ''); + list($p, $q) = $subselect->parse(); + $this->pointer += $p; + $select->addMultiQuery(MultiOperand::assert($type), $q); + + $next = $this->tokens[$this->pointer] ?? null; + } + + $this->pointer = $close + 1; + $next = $this->tokens[$this->pointer] ?? 
null; + + if ($next !== null && $next['value'] === 'AS') { + $this->pointer++; + $next = $this->tokens[$this->pointer]; + } + if ($next === null || $next['type'] !== TokenType::IDENTIFIER) { + throw new SQLFakeParseException('Every subquery must have an alias'); + } + $name = $next['value']; + + $table = shape( + 'name' => $name, + 'subquery' => $expr, + 'join_type' => JoinType::JOIN, + 'alias' => $name, + ); + return $table; + default: + throw new SQLFakeParseException('Expected table name or subquery'); + } + + } + + /* + * Seek as far forward as is needed to build the JOIN expression including join type, conditions, table name and alias + */ + private function buildJoin(string $left_table, token $token): from_table { + + // INNER JOIN and JOIN are aliases + $join_type = $token['value']; + if ($token['value'] === 'INNER') { + $join_type = 'JOIN'; + } + + if (C\contains_key(keyset['INNER', 'CROSS', 'NATURAL'], $token['value'])) { + $this->pointer++; + $next = $this->tokens[$this->pointer] ?? null; + if ($next === null || $next['value'] !== 'JOIN') { + throw new SQLFakeParseException("Expected keyword JOIN after {$token['value']}"); + } + } else if (C\contains_key(keyset['LEFT', 'RIGHT'], $token['value'])) { + $this->pointer++; + $next = $this->tokens[$this->pointer] ?? null; + if ($next !== null && $next['value'] === 'OUTER') { + $this->pointer++; + $next = $this->tokens[$this->pointer] ?? null; + } + if ($next === null || $next['value'] !== 'JOIN') { + throw new SQLFakeParseException("Expected keyword JOIN after {$token['value']}"); + } + } + + $this->pointer++; + $next = $this->tokens[$this->pointer] ?? null; + if ($next === null) { + throw new SQLFakeParseException('Expected table or subquery after join keyword'); + } + $table = $this->getTableOrSubquery($next); + $table['join_type'] = JoinType::assert($join_type); + + $this->pointer++; + $next = $this->tokens[$this->pointer] ?? 
null; + // it is possible to end the entire query here like "SELECT * from foo join bar" + if ($next === null) { + return $table; + } + + // maybe add alias + if ($next['type'] === TokenType::IDENTIFIER) { + $table['alias'] = $next['value']; + $this->pointer++; + $next = $this->tokens[$this->pointer] ?? null; + if ($next === null) { + return $table; + } + } else if ($next['value'] === 'AS') { + $this->pointer++; + $next = $this->tokens[$this->pointer] ?? null; + if ($next === null || $next['type'] !== TokenType::IDENTIFIER) { + throw new SQLFakeParseException('Expected identifier after AS'); + } + $table['alias'] = $next['value']; + $this->pointer++; + $next = $this->tokens[$this->pointer] ?? null; + if ($next === null) { + return $table; + } + } + + $this->pointer = SQLParser::skipIndexHints($this->pointer, $this->tokens); + + // this has to end here if it's a natural or cross join + // NOTE the natural join filter has to be built later, at runtime + if (C\contains_key(keyset[JoinType::NATURAL, JoinType::CROSS], $table['join_type'])) { + return $table; + } + + // now we need ON or USING + if ($next['type'] !== TokenType::RESERVED || !C\contains_key(keyset['ON', 'USING'], $next['value'])) { + throw new SQLFakeParseException('Expected ON or USING join condition'); + } + + if ($next['value'] === 'USING') { + $table['join_operator'] = JoinOperator::USING; + $this->pointer++; + $next = $this->tokens[$this->pointer] ?? 
null; + if ($next === null || $next['type'] !== TokenType::PAREN) { + throw new SQLFakeParseException('Expected ( after USING clause'); + } + $closing_paren_pointer = SQLParser::findMatchingParen($this->pointer, $this->tokens); + $arg_tokens = Vec\slice($this->tokens, $this->pointer + 1, $closing_paren_pointer - $this->pointer - 1); + if (!C\count($arg_tokens)) { + throw new SQLFakeParseException('Expected at least one argument to USING() clause'); + } + $count = 0; + $filter = null; + foreach ($arg_tokens as $arg) { + $count++; + if ($count % 2 === 1) { + // odd arguments should be columns + if ($arg['type'] !== TokenType::IDENTIFIER) { + throw new SQLFakeParseException('Expected identifier in USING clause'); + } + $filter = $this->addJoinFilterExpression($filter, $left_table, $table['name'], $arg['value']); + } else if ($arg['value'] !== ',') { + throw new SQLFakeParseException('Expected , after argument in USING clause'); + } + } + + $this->pointer = $closing_paren_pointer + 1; + $table['join_expression'] = $filter; + } else { + $table['join_operator'] = JoinOperator::ON; + $expression_parser = new ExpressionParser($this->tokens, $this->pointer); + list($this->pointer, $expression) = $expression_parser->buildWithPointer(); + $table['join_expression'] = $expression; + } + + return $table; + } + + public function addJoinFilterExpression( + ?Expression $filter, + string $left_table, + string $right_table, + string $column, + ): BinaryOperatorExpression { + + $left = new ColumnExpression(shape( + 'type' => TokenType::IDENTIFIER, + 'value' => "{$left_table}.{$column}", + 'raw' => '', + )); + $right = new ColumnExpression(shape( + 'type' => TokenType::IDENTIFIER, + 'value' => "{$right_table}.{$column}", + 'raw' => '', + )); + + // making a binary expression ensuring those two tokens are equal + $expr = new BinaryOperatorExpression($left, /* $negated */ false, Operator::EQUALS, $right); + + // if this is not the first condition, make an AND that wraps the current and 
new filter + if ($filter !== null) { + $filter = new BinaryOperatorExpression($filter, /* $negated */ false, Operator::AND, $expr); + } else { + $filter = $expr; + } + + return $filter; + } } diff --git a/src/Parser/InsertParser.php b/src/Parser/InsertParser.php index 7f2e2b5..c728702 100644 --- a/src/Parser/InsertParser.php +++ b/src/Parser/InsertParser.php @@ -12,225 +12,225 @@ */ final class InsertParser { - const dict CLAUSE_ORDER = dict[ - 'INSERT' => 1, - 'COLUMN_LIST' => 2, - 'VALUES' => 3, - 'ON' => 4, - 'SET' => 5, - ]; - - private string $currentClause = 'INSERT'; - private int $pointer = 0; - - public function __construct(private token_list $tokens, private string $sql) {} - - public function parse(): InsertQuery { - - // if we got here, the first token had better be a INSERT - if ($this->tokens[$this->pointer]['value'] !== 'INSERT') { - throw new SQLFakeParseException('Parser error: expected INSERT'); - } - $this->pointer++; - - // ignore these keywords which can come after INSERT - if (C\contains_key(keyset['LOW_PRIORITY', 'DELAYED', 'HIGH_PRIORITY'], $this->tokens[$this->pointer]['value'])) { - $this->pointer++; - } - - // IGNORE can come next and indicates duplicate keys should be ignored - $ignore_dupes = false; - if ($this->tokens[$this->pointer]['value'] === 'IGNORE') { - $ignore_dupes = true; - $this->pointer++; - } - - // INTO is optional and has no effect. 
skip it if present - if ($this->tokens[$this->pointer]['value'] === 'INTO') { - $this->pointer++; - } - - // next token has to be a table name - $token = $this->tokens[$this->pointer]; - if ($token === null || $token['type'] !== TokenType::IDENTIFIER) { - throw new SQLFakeParseException('Expected table name after INSERT'); - } - $this->pointer++; - - $query = new InsertQuery($token['value'], $this->sql, $ignore_dupes); - - $count = C\count($this->tokens); - - $needs_comma = false; - while ($this->pointer < $count) { - $token = $this->tokens[$this->pointer]; - - // handle VALUES() as a function, like "on duplicate key update column=VALUES(column)" - if ($this->currentClause === 'SET' && $token['value'] === 'VALUES') { - $token['type'] = TokenType::SQLFUNCTION; - } - - switch ($token['type']) { - case TokenType::CLAUSE: - // make sure clauses are in order - if ( - C\contains_key(self::CLAUSE_ORDER, $token['value']) && - self::CLAUSE_ORDER[$this->currentClause] >= self::CLAUSE_ORDER[$token['value']] - ) { - throw new SQLFakeParseException("Unexpected clause {$token['value']}"); - } - - switch ($token['value']) { - case 'VALUES': - $needs_another_plus_plus = false; - do { - $this->pointer++; - if ($needs_another_plus_plus) { - $this->pointer++; - $needs_another_plus_plus = false; - } - $token = $this->tokens[$this->pointer]; - // VALUES must be followed by paren and then a list of values - if ($token === null || $token['value'] !== '(') { - throw new SQLFakeParseException('Expected ( after VALUES'); - } - $close = SQLParser::findMatchingParen($this->pointer, $this->tokens); - $values_tokens = Vec\slice($this->tokens, $this->pointer + 1, $close - $this->pointer - 1); - $values = $this->parseValues($values_tokens); - if (C\count($values) !== C\count($query->insertColumns)) { - throw new SQLFakeParseException( - 'Insert list contains '. - C\count($query->insertColumns). - ' fields, but values clause contains '. 
- C\count($values), - ); - } - $query->values[] = $values; - $this->pointer = $close; - $needs_another_plus_plus = true; - } while (($this->tokens[$this->pointer + 1]['value'] ?? null) === ',' && $this->pointer); - // The while loop above used to havea $this->pointer++ here. ^^^^^^^^^^^^^^ - // We still need to increment is this condition is true. - if ($needs_another_plus_plus && ($this->tokens[$this->pointer + 1]['value'] ?? null) === ',') { - $this->pointer++; - } - break; - default: - throw new SQLFakeParseException("Unexpected clause {$token['value']}"); - } - $this->currentClause = $token['value']; - break; - case TokenType::IDENTIFIER: - if ($needs_comma) { - throw new SQLFakeParseException('Expected , between expressions in INSERT'); - } - - if ($this->currentClause !== 'COLUMN_LIST') { - throw new SQLFakeParseException("Unexpected token {$token['value']} in INSERT"); - } - - $query->insertColumns[] = $token['value']; - $needs_comma = true; - break; - case TokenType::PAREN: - // are we opening the insert list? - if ($this->currentClause === 'INSERT' && $token['value'] === '(') { - $this->currentClause = 'COLUMN_LIST'; - break; - } - - throw new SQLFakeParseException('Unexpected ('); - case TokenType::SEPARATOR: - if ($token['value'] === ',') { - if (!$needs_comma) { - throw new SQLFakeParseException('Unexpected ,'); - } - $needs_comma = false; - } else if ($this->currentClause === 'COLUMN_LIST' && $needs_comma && $token['value'] === ')') { - // closing the insert column list? - $needs_comma = false; - if (($this->tokens[$this->pointer + 1]['value'] ?? 
null) !== 'VALUES') { - throw new SQLFakeParseException('Expected VALUES after insert column list'); - } - break; - } else { - throw new SQLFakeParseException("Unexpected {$token['value']}"); - } - break; - case TokenType::RESERVED: - if ($token['value'] === 'ON') { - $expected = vec['DUPLICATE', 'KEY', 'UPDATE']; - $next_pointer = $this->pointer + 1; - foreach ($expected as $index => $keyword) { - $next = $this->tokens[$next_pointer + $index] ?? null; - if ($next === null || $next['value'] !== $keyword) { - throw new SQLFakeParseException('Unexpected keyword near ON'); - } - } - $this->pointer += 3; - $p = new SetParser($this->pointer, $this->tokens); - list($this->pointer, $query->updateExpressions) = $p->parse(/* $skip_set */ true); - break; - } - throw new SQLFakeParseException("Unexpected {$token['value']}"); - default: - throw new SQLFakeParseException("Unexpected token {$token['value']}"); - } - - $this->pointer++; - } - - if (C\is_empty($query->insertColumns) || C\is_empty($query->values)) { - throw new SQLFakeParseException('Missing values to insert'); - } - - return $query; - } - - /** - * Parse a VALUES clause into a list of expressions - */ - protected function parseValues(vec $tokens): vec { - $pointer = 0; - $count = C\count($tokens); - $expressions = vec[]; - - $needs_comma = false; - while ($pointer < $count) { - $token = $tokens[$pointer]; - switch ($token['type']) { - case TokenType::IDENTIFIER: - case TokenType::NUMERIC_CONSTANT: - case TokenType::BOOLEAN_CONSTANT: - case TokenType::STRING_CONSTANT: - case TokenType::NULL_CONSTANT: - case TokenType::OPERATOR: - case TokenType::SQLFUNCTION: - case TokenType::PAREN: - if ($needs_comma) { - throw new SQLFakeParseException("Expected , between expressions in SET clause near {$token['value']}"); - } - $expression_parser = new ExpressionParser($tokens, $pointer - 1); - list($pointer, $expression) = $expression_parser->buildWithPointer(); - $expressions[] = $expression; - $needs_comma = true; - break; - 
case TokenType::SEPARATOR: - if ($token['value'] === ',') { - if (!$needs_comma) { - echo 'le comma one'; - throw new SQLFakeParseException('Unexpected ,'); - } - $needs_comma = false; - } else { - throw new SQLFakeParseException("Unexpected {$token['value']}"); - } - break; - default: - throw new SQLFakeParseException("Unexpected token {$token['value']}"); - } - $pointer++; - } - return $expressions; - } + const dict CLAUSE_ORDER = dict[ + 'INSERT' => 1, + 'COLUMN_LIST' => 2, + 'VALUES' => 3, + 'ON' => 4, + 'SET' => 5, + ]; + + private string $currentClause = 'INSERT'; + private int $pointer = 0; + + public function __construct(private token_list $tokens, private string $sql) {} + + public function parse(): InsertQuery { + + // if we got here, the first token had better be a INSERT + if ($this->tokens[$this->pointer]['value'] !== 'INSERT') { + throw new SQLFakeParseException('Parser error: expected INSERT'); + } + $this->pointer++; + + // ignore these keywords which can come after INSERT + if (C\contains_key(keyset['LOW_PRIORITY', 'DELAYED', 'HIGH_PRIORITY'], $this->tokens[$this->pointer]['value'])) { + $this->pointer++; + } + + // IGNORE can come next and indicates duplicate keys should be ignored + $ignore_dupes = false; + if ($this->tokens[$this->pointer]['value'] === 'IGNORE') { + $ignore_dupes = true; + $this->pointer++; + } + + // INTO is optional and has no effect. 
skip it if present + if ($this->tokens[$this->pointer]['value'] === 'INTO') { + $this->pointer++; + } + + // next token has to be a table name + $token = $this->tokens[$this->pointer]; + if ($token === null || $token['type'] !== TokenType::IDENTIFIER) { + throw new SQLFakeParseException('Expected table name after INSERT'); + } + $this->pointer++; + + $query = new InsertQuery($token['value'], $this->sql, $ignore_dupes); + + $count = C\count($this->tokens); + + $needs_comma = false; + while ($this->pointer < $count) { + $token = $this->tokens[$this->pointer]; + + // handle VALUES() as a function, like "on duplicate key update column=VALUES(column)" + if ($this->currentClause === 'SET' && $token['value'] === 'VALUES') { + $token['type'] = TokenType::SQLFUNCTION; + } + + switch ($token['type']) { + case TokenType::CLAUSE: + // make sure clauses are in order + if ( + C\contains_key(self::CLAUSE_ORDER, $token['value']) && + self::CLAUSE_ORDER[$this->currentClause] >= self::CLAUSE_ORDER[$token['value']] + ) { + throw new SQLFakeParseException("Unexpected clause {$token['value']}"); + } + + switch ($token['value']) { + case 'VALUES': + $needs_another_plus_plus = false; + do { + $this->pointer++; + if ($needs_another_plus_plus) { + $this->pointer++; + $needs_another_plus_plus = false; + } + $token = $this->tokens[$this->pointer]; + // VALUES must be followed by paren and then a list of values + if ($token === null || $token['value'] !== '(') { + throw new SQLFakeParseException('Expected ( after VALUES'); + } + $close = SQLParser::findMatchingParen($this->pointer, $this->tokens); + $values_tokens = Vec\slice($this->tokens, $this->pointer + 1, $close - $this->pointer - 1); + $values = $this->parseValues($values_tokens); + if (C\count($values) !== C\count($query->insertColumns)) { + throw new SQLFakeParseException( + 'Insert list contains '. + C\count($query->insertColumns). + ' fields, but values clause contains '. 
+ C\count($values), + ); + } + $query->values[] = $values; + $this->pointer = $close; + $needs_another_plus_plus = true; + } while (($this->tokens[$this->pointer + 1]['value'] ?? null) === ',' && $this->pointer); + // The while loop above used to havea $this->pointer++ here. ^^^^^^^^^^^^^^ + // We still need to increment is this condition is true. + if ($needs_another_plus_plus && ($this->tokens[$this->pointer + 1]['value'] ?? null) === ',') { + $this->pointer++; + } + break; + default: + throw new SQLFakeParseException("Unexpected clause {$token['value']}"); + } + $this->currentClause = $token['value']; + break; + case TokenType::IDENTIFIER: + if ($needs_comma) { + throw new SQLFakeParseException('Expected , between expressions in INSERT'); + } + + if ($this->currentClause !== 'COLUMN_LIST') { + throw new SQLFakeParseException("Unexpected token {$token['value']} in INSERT"); + } + + $query->insertColumns[] = $token['value']; + $needs_comma = true; + break; + case TokenType::PAREN: + // are we opening the insert list? + if ($this->currentClause === 'INSERT' && $token['value'] === '(') { + $this->currentClause = 'COLUMN_LIST'; + break; + } + + throw new SQLFakeParseException('Unexpected ('); + case TokenType::SEPARATOR: + if ($token['value'] === ',') { + if (!$needs_comma) { + throw new SQLFakeParseException('Unexpected ,'); + } + $needs_comma = false; + } else if ($this->currentClause === 'COLUMN_LIST' && $needs_comma && $token['value'] === ')') { + // closing the insert column list? + $needs_comma = false; + if (($this->tokens[$this->pointer + 1]['value'] ?? 
null) !== 'VALUES') { + throw new SQLFakeParseException('Expected VALUES after insert column list'); + } + break; + } else { + throw new SQLFakeParseException("Unexpected {$token['value']}"); + } + break; + case TokenType::RESERVED: + if ($token['value'] === 'ON') { + $expected = vec['DUPLICATE', 'KEY', 'UPDATE']; + $next_pointer = $this->pointer + 1; + foreach ($expected as $index => $keyword) { + $next = $this->tokens[$next_pointer + $index] ?? null; + if ($next === null || $next['value'] !== $keyword) { + throw new SQLFakeParseException('Unexpected keyword near ON'); + } + } + $this->pointer += 3; + $p = new SetParser($this->pointer, $this->tokens); + list($this->pointer, $query->updateExpressions) = $p->parse(/* $skip_set */ true); + break; + } + throw new SQLFakeParseException("Unexpected {$token['value']}"); + default: + throw new SQLFakeParseException("Unexpected token {$token['value']}"); + } + + $this->pointer++; + } + + if (C\is_empty($query->insertColumns) || C\is_empty($query->values)) { + throw new SQLFakeParseException('Missing values to insert'); + } + + return $query; + } + + /** + * Parse a VALUES clause into a list of expressions + */ + protected function parseValues(vec $tokens): vec { + $pointer = 0; + $count = C\count($tokens); + $expressions = vec[]; + + $needs_comma = false; + while ($pointer < $count) { + $token = $tokens[$pointer]; + switch ($token['type']) { + case TokenType::IDENTIFIER: + case TokenType::NUMERIC_CONSTANT: + case TokenType::BOOLEAN_CONSTANT: + case TokenType::STRING_CONSTANT: + case TokenType::NULL_CONSTANT: + case TokenType::OPERATOR: + case TokenType::SQLFUNCTION: + case TokenType::PAREN: + if ($needs_comma) { + throw new SQLFakeParseException("Expected , between expressions in SET clause near {$token['value']}"); + } + $expression_parser = new ExpressionParser($tokens, $pointer - 1); + list($pointer, $expression) = $expression_parser->buildWithPointer(); + $expressions[] = $expression; + $needs_comma = true; + break; + 
case TokenType::SEPARATOR: + if ($token['value'] === ',') { + if (!$needs_comma) { + echo 'le comma one'; + throw new SQLFakeParseException('Unexpected ,'); + } + $needs_comma = false; + } else { + throw new SQLFakeParseException("Unexpected {$token['value']}"); + } + break; + default: + throw new SQLFakeParseException("Unexpected token {$token['value']}"); + } + $pointer++; + } + return $expressions; + } } diff --git a/src/Parser/Lexer.php b/src/Parser/Lexer.php index 94878ee..b3186fb 100644 --- a/src/Parser/Lexer.php +++ b/src/Parser/Lexer.php @@ -11,172 +11,172 @@ */ final class SQLLexer { - // this regex contains all MySQL operands and words patterns (including whitespace) that can separate SQL Tokens - // the longer ones come first which is important because of how preg_split works - private static string $token_split_regex = - '/(\<\=\>|\r\n|\!\=|\>\=|\<\=|\<\>|\<\<|\>\>|\:\=|&&|\|\||\:\=|\/\*|\*\/|\-\-|\>|\<|\||\=|\^|\(|\)|\t|\n|\'|"|`|,|@|\s|\+|\-|\*|\/|;|\\\)/'; - - public function lex(string $sql): vec { - - $tokens = preg_split(self::$token_split_regex, $sql, null, \PREG_SPLIT_DELIM_CAPTURE | \PREG_SPLIT_NO_EMPTY) - |> dict($$); - - // first process SQL comments, grouping them into one token so that the rest of the statements don't need to worry about handling comments - // since comments are rare, save a bit of time by skipping this if the SQL can't possibly contain comments - if (preg_match('![/#-]!', $sql)) { - $tokens = $this->groupComments($tokens); - } - - // for a string like 'Scott\\'s' this puts the ' character in the right place, necessary for balance_quotes to works - // since backslashes are rare, save time by skipping this if backslashes aren't present - if (Str\contains($sql, '\\')) { - $tokens = $this->groupEscapeSequences($tokens); - } - - // quotes are common, always do this - return $this->groupQuotedTokens($tokens); - } - - // group comments into a single token, then remove them - private function groupComments(dict $tokens): dict { - 
- $comment = null; - $inline = false; - $count = C\count($tokens); - $escape_next = false; - $quote = null; - - for ($i = 0; $i < $count; $i++) { - $token = $tokens[$i]; - - // are we inside a comment already? - if ($comment !== null) { - if ($inline && ($token === "\n" || $token === "\r\n")) { - unset($tokens[$comment]); - $comment = null; - } else { - unset($tokens[$i]); - $tokens[$comment] .= $token; - } - - if (!$inline && ($token === '*/')) { - if ($comment is nonnull) { - unset($tokens[$comment]); - } - $comment = null; - } - continue; - } - - if (!$escape_next && C\contains_key(keyset['\'', '"'], $token)) { - if ($quote !== null && $quote === $token) { - $quote = null; - } else if ($quote !== null) { - continue; - } else { - $quote = $token; - } - } - - // we need to handle sequences that look like comments, but are inside quoted strings. to do that, we also need to know when quoted strings start and end - // since a comment could contain a quote, and a quote could contain a comment, we can't safely process either first without being aware of the other - // so first we check if the next token should be escaped, and then if it's a quote character - // checking for !$escape_next here checks for \\\\ sequences - if ($token === '\\' && !$escape_next) { - $escape_next = true; - } else { - $escape_next = false; - } - - // if we are inside a quoted string, do not check for comment sequences - if ($quote !== null) { - continue; - } - - // MySQL requires a space after double dash for it to be counted as a comment: https://dev.mysql.com/doc/refman/5.7/en/ansi-diff-comments.html - if ($token === '--') { - $comment = $i; - $inline = true; - } - - // hash comments don't require a space - if (Str\starts_with($token, '#')) { - $comment = $i; - $inline = true; - } - - if ($token === '/*') { - $comment = $i; - $inline = false; - } - } - - // re-key - return dict(vec($tokens)); - } - - private function groupEscapeSequences(dict $tokens): dict { - $tokenCount = 
C\count($tokens); - $i = 0; - - while ($i < $tokenCount) { - if (Str\ends_with($tokens[$i], '\\')) { - $i++; - if (C\contains_key($tokens, $i)) { - $tokens[$i - 1] .= $tokens[$i]; - unset($tokens[$i]); - } - } - $i++; - } - - // re-key - return dict(vec($tokens)); - } - - private function groupQuotedTokens(dict $tokens): vec { - $i = 0; - $count = C\count($tokens); - while ($i < $count) { - $token = $tokens[$i]; - - // single quotes, double quotes, or backticks - // when we find a quote, seek forward to the next matching quote and combine all tokens within - if (C\contains_key(keyset['\'', '"', '`'], $token)) { - $quote = $token; - $quote_start = $i; - $i++; - $found_match = false; - while ($i < $count) { - $t = $tokens[$i]; - // up to and including when we find a match, unroll each new token and add it to the current $token - $token .= $t; - unset($tokens[$i]); - $i++; - if ($t === $quote) { - $found_match = true; - // if the quotes are followed by a dot, add it in too (e.g. `database`.tablename) - if ($i < $count && \mb_substr($tokens[$i], 0, 1) === '.') { - $t = $tokens[$i]; - $token .= $t; - unset($tokens[$i]); - $i++; - } - break; - } - } - - if (!$found_match) { - throw new SQLFakeParseException("Unbalanced quote $quote"); - } - - $tokens[$quote_start] = $token; - continue; - } - - $i++; - } - - return vec($tokens); - } + // this regex contains all MySQL operands and words patterns (including whitespace) that can separate SQL Tokens + // the longer ones come first which is important because of how preg_split works + private static string $token_split_regex = + '/(\<\=\>|\r\n|\!\=|\>\=|\<\=|\<\>|\<\<|\>\>|\:\=|&&|\|\||\:\=|\/\*|\*\/|\-\-|\>|\<|\||\=|\^|\(|\)|\t|\n|\'|"|`|,|@|\s|\+|\-|\*|\/|;|\\\)/'; + + public function lex(string $sql): vec { + + $tokens = preg_split(self::$token_split_regex, $sql, null, \PREG_SPLIT_DELIM_CAPTURE | \PREG_SPLIT_NO_EMPTY) + |> dict($$); + + // first process SQL comments, grouping them into one token so that the rest of the 
statements don't need to worry about handling comments + // since comments are rare, save a bit of time by skipping this if the SQL can't possibly contain comments + if (preg_match('![/#-]!', $sql)) { + $tokens = $this->groupComments($tokens); + } + + // for a string like 'Scott\\'s' this puts the ' character in the right place, necessary for balance_quotes to works + // since backslashes are rare, save time by skipping this if backslashes aren't present + if (Str\contains($sql, '\\')) { + $tokens = $this->groupEscapeSequences($tokens); + } + + // quotes are common, always do this + return $this->groupQuotedTokens($tokens); + } + + // group comments into a single token, then remove them + private function groupComments(dict $tokens): dict { + + $comment = null; + $inline = false; + $count = C\count($tokens); + $escape_next = false; + $quote = null; + + for ($i = 0; $i < $count; $i++) { + $token = $tokens[$i]; + + // are we inside a comment already? + if ($comment !== null) { + if ($inline && ($token === "\n" || $token === "\r\n")) { + unset($tokens[$comment]); + $comment = null; + } else { + unset($tokens[$i]); + $tokens[$comment] .= $token; + } + + if (!$inline && ($token === '*/')) { + if ($comment is nonnull) { + unset($tokens[$comment]); + } + $comment = null; + } + continue; + } + + if (!$escape_next && C\contains_key(keyset['\'', '"'], $token)) { + if ($quote !== null && $quote === $token) { + $quote = null; + } else if ($quote !== null) { + continue; + } else { + $quote = $token; + } + } + + // we need to handle sequences that look like comments, but are inside quoted strings. 
to do that, we also need to know when quoted strings start and end + // since a comment could contain a quote, and a quote could contain a comment, we can't safely process either first without being aware of the other + // so first we check if the next token should be escaped, and then if it's a quote character + // checking for !$escape_next here checks for \\\\ sequences + if ($token === '\\' && !$escape_next) { + $escape_next = true; + } else { + $escape_next = false; + } + + // if we are inside a quoted string, do not check for comment sequences + if ($quote !== null) { + continue; + } + + // MySQL requires a space after double dash for it to be counted as a comment: https://dev.mysql.com/doc/refman/5.7/en/ansi-diff-comments.html + if ($token === '--') { + $comment = $i; + $inline = true; + } + + // hash comments don't require a space + if (Str\starts_with($token, '#')) { + $comment = $i; + $inline = true; + } + + if ($token === '/*') { + $comment = $i; + $inline = false; + } + } + + // re-key + return dict(vec($tokens)); + } + + private function groupEscapeSequences(dict $tokens): dict { + $tokenCount = C\count($tokens); + $i = 0; + + while ($i < $tokenCount) { + if (Str\ends_with($tokens[$i], '\\')) { + $i++; + if (C\contains_key($tokens, $i)) { + $tokens[$i - 1] .= $tokens[$i]; + unset($tokens[$i]); + } + } + $i++; + } + + // re-key + return dict(vec($tokens)); + } + + private function groupQuotedTokens(dict $tokens): vec { + $i = 0; + $count = C\count($tokens); + while ($i < $count) { + $token = $tokens[$i]; + + // single quotes, double quotes, or backticks + // when we find a quote, seek forward to the next matching quote and combine all tokens within + if (C\contains_key(keyset['\'', '"', '`'], $token)) { + $quote = $token; + $quote_start = $i; + $i++; + $found_match = false; + while ($i < $count) { + $t = $tokens[$i]; + // up to and including when we find a match, unroll each new token and add it to the current $token + $token .= $t; + 
unset($tokens[$i]); + $i++; + if ($t === $quote) { + $found_match = true; + // if the quotes are followed by a dot, add it in too (e.g. `database`.tablename) + if ($i < $count && \mb_substr($tokens[$i], 0, 1) === '.') { + $t = $tokens[$i]; + $token .= $t; + unset($tokens[$i]); + $i++; + } + break; + } + } + + if (!$found_match) { + throw new SQLFakeParseException("Unbalanced quote $quote"); + } + + $tokens[$quote_start] = $token; + continue; + } + + $i++; + } + + return vec($tokens); + } } diff --git a/src/Parser/LimitParser.php b/src/Parser/LimitParser.php index c3ae781..c2de176 100644 --- a/src/Parser/LimitParser.php +++ b/src/Parser/LimitParser.php @@ -5,44 +5,44 @@ // parse the LIMIT clause, which can be used for SELECT, UPDATE, or DELETE final class LimitParser { - public function __construct(private int $pointer, private token_list $tokens) {} + public function __construct(private int $pointer, private token_list $tokens) {} - public function parse(): (int, limit_clause) { + public function parse(): (int, limit_clause) { - // if we got here, the first token had better be LIMIT - if ($this->tokens[$this->pointer]['value'] !== 'LIMIT') { - throw new SQLFakeParseException('Parser error: expected LIMIT'); - } - $this->pointer++; - $next = $this->tokens[$this->pointer] ?? null; - if ($next === null || $next['type'] !== TokenType::NUMERIC_CONSTANT) { - throw new SQLFakeParseException('Expected integer after LIMIT'); - } - $limit = (int)($next['value']); - $offset = 0; - $next = $this->tokens[$this->pointer + 1] ?? null; - if ($next !== null) { - if ($next['value'] === 'OFFSET') { - $this->pointer += 2; - $next = $this->tokens[$this->pointer] ?? null; - if ($next === null || $next['type'] !== TokenType::NUMERIC_CONSTANT) { - throw new SQLFakeParseException('Expected integer after OFFSET'); - } - $offset = (int)($next['value']); - } else if ($next['value'] === ',') { - $this->pointer += 2; - $next = $this->tokens[$this->pointer] ?? 
null; - if ($next === null || $next['type'] !== TokenType::NUMERIC_CONSTANT) { - throw new SQLFakeParseException('Expected integer after OFFSET'); - } + // if we got here, the first token had better be LIMIT + if ($this->tokens[$this->pointer]['value'] !== 'LIMIT') { + throw new SQLFakeParseException('Parser error: expected LIMIT'); + } + $this->pointer++; + $next = $this->tokens[$this->pointer] ?? null; + if ($next === null || $next['type'] !== TokenType::NUMERIC_CONSTANT) { + throw new SQLFakeParseException('Expected integer after LIMIT'); + } + $limit = (int)($next['value']); + $offset = 0; + $next = $this->tokens[$this->pointer + 1] ?? null; + if ($next !== null) { + if ($next['value'] === 'OFFSET') { + $this->pointer += 2; + $next = $this->tokens[$this->pointer] ?? null; + if ($next === null || $next['type'] !== TokenType::NUMERIC_CONSTANT) { + throw new SQLFakeParseException('Expected integer after OFFSET'); + } + $offset = (int)($next['value']); + } else if ($next['value'] === ',') { + $this->pointer += 2; + $next = $this->tokens[$this->pointer] ?? null; + if ($next === null || $next['type'] !== TokenType::NUMERIC_CONSTANT) { + throw new SQLFakeParseException('Expected integer after OFFSET'); + } - // in LIMIT 1, 100 the offset is 1 and 100 is the row count, so swap them here - // confusing, right? - $offset = $limit; - $limit = (int)($next['value']); - } - } + // in LIMIT 1, 100 the offset is 1 and 100 is the row count, so swap them here + // confusing, right? 
+ $offset = $limit; + $limit = (int)($next['value']); + } + } - return tuple($this->pointer, shape('rowcount' => $limit, 'offset' => $offset)); - } + return tuple($this->pointer, shape('rowcount' => $limit, 'offset' => $offset)); + } } diff --git a/src/Parser/OrderByParser.php b/src/Parser/OrderByParser.php index 27f9d22..643c520 100644 --- a/src/Parser/OrderByParser.php +++ b/src/Parser/OrderByParser.php @@ -7,64 +7,64 @@ // parse the ORDER BY clause, which can be used for SELECT, UPDATE, or DELETE final class OrderByParser { - public function __construct( - private int $pointer, - private token_list $tokens, - // this one is only used for SELECT queries. - private ?vec $selectExpressions = null, - ) {} + public function __construct( + private int $pointer, + private token_list $tokens, + // this one is only used for SELECT queries. + private ?vec $selectExpressions = null, + ) {} - public function parse(): (int, order_by_clause) { + public function parse(): (int, order_by_clause) { - // if we got here, the first token had better be ORDER - if ($this->tokens[$this->pointer]['value'] !== 'ORDER') { - throw new SQLFakeParseException('Parser error: expected ORDER'); - } + // if we got here, the first token had better be ORDER + if ($this->tokens[$this->pointer]['value'] !== 'ORDER') { + throw new SQLFakeParseException('Parser error: expected ORDER'); + } - $this->pointer++; - $next = $this->tokens[$this->pointer] ?? null; - $expressions = vec[]; - if ($next === null || $next['value'] !== 'BY') { - throw new SQLFakeParseException('Expected BY after ORDER'); - } + $this->pointer++; + $next = $this->tokens[$this->pointer] ?? 
null; + $expressions = vec[]; + if ($next === null || $next['value'] !== 'BY') { + throw new SQLFakeParseException('Expected BY after ORDER'); + } - while (true) { - $expression_parser = new ExpressionParser($this->tokens, $this->pointer); - if ($this->selectExpressions is nonnull) { - $expression_parser->setSelectExpressions($this->selectExpressions); - } - list($this->pointer, $expression) = $expression_parser->buildWithPointer(); + while (true) { + $expression_parser = new ExpressionParser($this->tokens, $this->pointer); + if ($this->selectExpressions is nonnull) { + $expression_parser->setSelectExpressions($this->selectExpressions); + } + list($this->pointer, $expression) = $expression_parser->buildWithPointer(); - // any constants in the ORDER BY must be positional references - if ($expression is ConstantExpression) { - // grab the associated position from the select list - $position = (int)($expression->value); + // any constants in the ORDER BY must be positional references + if ($expression is ConstantExpression) { + // grab the associated position from the select list + $position = (int)($expression->value); - $expression = $this->selectExpressions[$position - 1] ?? null; - if ($expression is null) { - throw new SQLFakeParseException("ORDER BY positional field $position not found in SELECT list"); - } - } + $expression = $this->selectExpressions[$position - 1] ?? null; + if ($expression is null) { + throw new SQLFakeParseException("ORDER BY positional field $position not found in SELECT list"); + } + } - $next = $this->tokens[$this->pointer + 1] ?? null; + $next = $this->tokens[$this->pointer + 1] ?? null; - // default to ASC - $sort_direction = SortDirection::ASC; - if ($next !== null && C\contains_key(keyset['ASC', 'DESC'], $next['value'])) { - $this->pointer++; - $sort_direction = SortDirection::assert($next['value']); - $next = $this->tokens[$this->pointer + 1] ?? 
null; - } + // default to ASC + $sort_direction = SortDirection::ASC; + if ($next !== null && C\contains_key(keyset['ASC', 'DESC'], $next['value'])) { + $this->pointer++; + $sort_direction = SortDirection::assert($next['value']); + $next = $this->tokens[$this->pointer + 1] ?? null; + } - $expressions[] = shape('expression' => $expression, 'direction' => $sort_direction); + $expressions[] = shape('expression' => $expression, 'direction' => $sort_direction); - // skip over commas and continue the processing, but if it's any other token break out of the loop - if ($next === null || $next['value'] !== ',') { - break; - } - $this->pointer++; - } + // skip over commas and continue the processing, but if it's any other token break out of the loop + if ($next === null || $next['value'] !== ',') { + break; + } + $this->pointer++; + } - return tuple($this->pointer, $expressions); - } + return tuple($this->pointer, $expressions); + } } diff --git a/src/Parser/SQLParser.php b/src/Parser/SQLParser.php index 6af568d..dfab837 100644 --- a/src/Parser/SQLParser.php +++ b/src/Parser/SQLParser.php @@ -6,606 +6,606 @@ final class SQLParser { - public static function parse(string $sql): Query { - // memoize hit rate on write queries is very low - so only memoize selects to avoid ballooning memory usage - if (Str\starts_with_ci($sql, 'SELECT')) { - return static::parseMemoized($sql); - } - return static::parseImpl($sql); - } - - private static function parseImpl(string $sql): Query { - $tokens = (new SQLLexer())->lex($sql); - $tokens = self::buildTokenListFromLexemes($tokens); - - $token = $tokens[0]; - // handle a query like (SELECT 1), just strip the surrounding parens - if ($token['type'] === TokenType::PAREN) { - $close = self::findMatchingParen(0, $tokens); - $tokens = Vec\slice($tokens, 1, $close - 1); - $token = $tokens[0]; - } - - if ($token['type'] !== TokenType::CLAUSE) { - throw new SQLFakeParseException("Unexpected {$token['value']}"); - } - - switch ($token['value']) { - 
case 'SELECT': - $select = new SelectParser(0, $tokens, $sql); - list($pointer, $query) = $select->parse(); - // we still have something left here after parsing the whole top level query? hopefully it's a multi-query keyword - if (C\contains_key($tokens, $pointer)) { - $next = $tokens[$pointer] ?? null; - while ($next !== null && C\contains_key(keyset['UNION', 'INTERSECT', 'EXCEPT'], $next['value'])) { - $type = $next['value']; - if ($next['value'] === 'UNION') { - $next_plus = $tokens[$pointer + 1]; - if ($next_plus['value'] === 'ALL') { - $type = 'UNION_ALL'; - $pointer++; - } - if ($next_plus['value'] === 'DISTINCT') { - $pointer++; - } - } - $pointer++; - $select = new SelectParser($pointer, $tokens, $sql); - list($pointer, $q) = $select->parse(); - $query->addMultiQuery(MultiOperand::assert($type), $q); - - $next = $tokens[$pointer] ?? null; - } - } - return $query; - case 'UPDATE': - $update = new UpdateParser($tokens, $sql); - return $update->parse(); - case 'DELETE': - $delete = new DeleteParser($tokens, $sql); - return $delete->parse(); - case 'INSERT': - $insert = new InsertParser($tokens, $sql); - return $insert->parse(); - default: - throw new SQLFakeParseException("Unexpected {$token['value']}"); - } - - throw new SQLFakeParseException('Parse error: unexpected end of input'); - } - - <<__Memoize>> - private static function parseMemoized(string $sql): Query { - return static::parseImpl($sql); - } - - /* - * Lexemes are just a vec of strings from the lexer - * This builds them into typed token shapes based on lists of reserved keywords and surrounding context - */ - private static function buildTokenListFromLexemes(vec $tokens): token_list { - $out = vec[]; - $count = C\count($tokens); - foreach ($tokens as $i => $token) { - // skip white space, but tack it onto the tokens in another field for when we need to assemble the expression list - if (Str\trim($token) === '') { - $k = C\last_key($out); - if ($k !== null) { - $previous = $out[$k]; - 
$previous['raw'] .= $token; - $out[$k] = $previous; - } - continue; - } - - if (Str\to_int($token) is nonnull) { - $out[] = shape( - 'type' => TokenType::NUMERIC_CONSTANT, - 'value' => $token, - 'raw' => $token, - ); - continue; - } else if (C\contains_key(keyset['\'', '"'], $token[0])) { - // chop off the quotes before storing the value - $raw = $token; - $token = Str\slice($token, 1, Str\length($token) - 2); - - // unescape everything except for % and _ (which only get unescaped during LIKE operations) - // there are a few other special sequnces we leave unescaped like \r, \n, \t, \b, \Z, \0 - // https://dev.mysql.com/doc/refman/5.7/en/string-literals.html - if (Str\contains($token, '\\')) { - $token_replaced = ''; - $escape_next = false; - for ($i = 0; $i < Str\length($token); $i++) { - if ($escape_next) { - switch ($token[$i]) { - case 'r': - $token_replaced .= "\r"; - break; - case 'n': - $token_replaced .= "\n"; - break; - case '0': - $token_replaced .= "\0"; - break; - case 't': - $token_replaced .= "\t"; - break; - case '\\': - $token_replaced .= '\\'; - break; - case '%': - case '_': - // these stay unescaped unless used in LIKE - $token_replaced .= "\\{$token[$i]}"; - break; - default: - $token_replaced .= $token[$i]; - } - $escape_next = false; - } else if ($token[$i] === '\\') { - $escape_next = true; - } else { - $token_replaced .= $token[$i]; - } - } - $token = $token_replaced; - } - $out[] = shape( - 'type' => TokenType::STRING_CONSTANT, - 'value' => $token, - 'raw' => $raw, - ); - continue; - } else if ($token[0] === '`') { - $raw = $token; - // Only chop off the ` if it's fully wrapping the identifier - if (Str\ends_with($token, '`')) { - $token = Str\strip_prefix($token, '`') |> Str\strip_suffix($$, '`'); - } else if (Str\ends_with($token, '`.')) { - // for `foo`., it becomes foo. 
- $token = Str\strip_prefix($token, '`') |> Str\strip_suffix($$, '`.') |> $$.'.'; - } - - // if we find an identifier and previous token ended with ., smush them together - $previous_key = C\last_key($out); - if ( - $previous_key is nonnull && - $out[$previous_key]['type'] === TokenType::IDENTIFIER && - Str\ends_with($out[$previous_key]['value'], '.') - ) { - $out[$previous_key]['value'] .= $token; - $out[$previous_key]['raw'] .= $raw; - continue; - } - - $out[] = shape( - 'type' => TokenType::IDENTIFIER, - 'value' => $token, - 'raw' => $raw, - ); - continue; - } else if ($token[0] === '(') { - $out[] = shape('type' => TokenType::PAREN, 'value' => $token, 'raw' => $token); - continue; - } else if ($token === '*') { - // the * character is special because it's sometimes an operator and most of the time it means "all columns" - $k = C\last_key($out); - if ($k === null) { - throw new SQLFakeParseException('Parse error: unexpected *'); - } - $previous = $out[$k]; - $out[$k] = $previous; - if ( - !C\contains_key( - keyset[ - TokenType::NUMERIC_CONSTANT, - TokenType::BOOLEAN_CONSTANT, - TokenType::STRING_CONSTANT, - TokenType::NULL_CONSTANT, - TokenType::IDENTIFIER, - ], - $previous['type'], - ) && - $previous['value'] !== ')' - ) { - $out[] = shape( - 'type' => TokenType::IDENTIFIER, - 'value' => $token, - 'raw' => $token, - ); - continue; - } else if ($previous['type'] === TokenType::IDENTIFIER && Str\ends_with($previous['value'], '.')) { - // previous ended like "foo.", we should keep "foo.*" together as one token - $previous['value'] .= $token; - $previous['raw'] .= $token; - $out[$k] = $previous; - continue; - } - } else if ($token === '-' || $token === '+') { - // these operands can be binary or unary operands - // for example "SELECT -5" or "SELECT 7 - 5" are both valid, in the first case it's a unary op - // we don't just combine it into the constant, because it's also valid for columns like SELECT -some_column FROM table - // similar to * we can identify this 
now based on context - $k = C\last_key($out); - if ($k === null) { - throw new SQLFakeParseException("Parse error: unexpected {$token}"); - } - $previous = $out[$k]; - if ( - !C\contains_key( - keyset[ - TokenType::NUMERIC_CONSTANT, - TokenType::BOOLEAN_CONSTANT, - TokenType::STRING_CONSTANT, - TokenType::NULL_CONSTANT, - TokenType::IDENTIFIER, - ], - $previous['type'], - - ) && - $previous['value'] !== ')' - ) { - if ($token === '-') { - $op = 'UNARY_MINUS'; - } else { - $op = 'UNARY_PLUS'; - } - $out[] = shape( - 'type' => TokenType::OPERATOR, - 'value' => $op, - 'raw' => $token, - ); - continue; - } - } - - $token_upper = Str\uppercase($token); - - if ($token_upper === 'NULL') { - $out[] = shape( - 'type' => TokenType::NULL_CONSTANT, - 'value' => $token, - 'raw' => $token, - ); - } else if ($token_upper === 'TRUE') { - $out[] = shape( - 'type' => TokenType::BOOLEAN_CONSTANT, - 'value' => '1', - 'raw' => $token, - ); - } else if ($token_upper === 'FALSE') { - $out[] = shape( - 'type' => TokenType::BOOLEAN_CONSTANT, - 'value' => '0', - 'raw' => $token, - ); - } else if (C\contains_key(self::CLAUSES, $token_upper)) { - $out[] = shape( - 'type' => TokenType::CLAUSE, - 'value' => $token_upper, - 'raw' => $token, - ); - } else if ( - C\contains_key(self::OPERATORS, $token_upper) && - !self::isFunctionVersionOfOperator($token_upper, $i, $count, $tokens) - ) { - $out[] = shape( - 'type' => TokenType::OPERATOR, - 'value' => $token_upper, - 'raw' => $token, - ); - } else if (C\contains_key(self::RESERVED_WORDS, $token_upper)) { - $out[] = shape( - 'type' => TokenType::RESERVED, - 'value' => $token_upper, - 'raw' => $token, - ); - } else if (C\contains_key(self::SEPARATORS, $token_upper)) { - $out[] = shape( - 'type' => TokenType::SEPARATOR, - 'value' => $token_upper, - 'raw' => $token, - ); - } else if ($i < $count - 1 && $tokens[$i + 1] === '(') { - $out[] = shape( - 'type' => TokenType::SQLFUNCTION, - 'value' => $token_upper, - 'raw' => $token, - ); - } else { - // if 
we find an identifier and previous token ended with ., smush them together - $previous_key = C\last_key($out); - if ( - $previous_key is nonnull && - $out[$previous_key]['type'] === TokenType::IDENTIFIER && - Str\ends_with($out[$previous_key]['value'], '.') - ) { - $out[$previous_key]['value'] .= $token; - continue; - } - $out[] = shape( - 'type' => TokenType::IDENTIFIER, - 'value' => $token, - 'raw' => $token, - ); - } - } - return $out; - } - - /** - * There seem to be a few operators that also exists as functions. MOD() for example. - * So we check if this particular find is an operator or a function. - */ - private static function isFunctionVersionOfOperator( - string $token_upper, - int $i, - int $count, - vec $tokens, - ): bool { - return $token_upper === 'MOD' && $i < $count - 1 && $tokens[$i + 1] === '('; - } - - public static function findMatchingParen(int $pointer, token_list $tokens): int { - $paren_count = 0; - $remaining_tokens = Vec\drop($tokens, $pointer); - foreach ($remaining_tokens as $i => $token) { - if ($token['type'] === TokenType::PAREN) { - $paren_count++; - } else if ($token['type'] === TokenType::SEPARATOR && $token['value'] === ')') { - $paren_count--; - if ($paren_count === 0) { - return $pointer + $i; - } - } - } - - // if we get here, we didn't find the close - throw new SQLFakeParseException("Unclosed parentheses at index $pointer"); - } - - /* - * Skip over index hints, but might as well still syntax validate them while doing so - * Examples of index hints: - * FROM table1 USE INDEX (col1_index,col2_index) - * FROM t1 USE INDEX (i1) IGNORE INDEX FOR ORDER BY (i2) - */ - public static function skipIndexHints(int $pointer, token_list $tokens): int { - $next_pointer = $pointer + 1; - $next = $tokens[$next_pointer] ?? 
null; - while ( - $next !== null && - $next['type'] === TokenType::RESERVED && - C\contains_key(keyset['USE', 'IGNORE', 'FORCE'], $next['value']) - ) { - $pointer += 2; - $hint_type = $next['value']; - $next = $tokens[$pointer] ?? null; - if ($next === null || !C\contains_key(keyset['INDEX', 'KEY'], $next['value'])) { - throw new SQLFakeParseException('Expected INDEX or KEY in index hint'); - } - - $pointer++; - $next = $tokens[$pointer] ?? null; - if ($next === null) { - // USE hint is allowed to stop at "USE INDEX" which means "use no indexes" - if ($hint_type === 'USE') { - $pointer--; - return $pointer; - } - throw new SQLFakeParseException('Expected expected FOR or index list in index hint'); - } - - if ($next['value'] === 'FOR') { - $pointer++; - $next = $tokens[$pointer] ?? null; - if ($next === null) { - throw new SQLFakeParseException('Expected JOIN, ORDER BY, or GROUP BY after FOR in index hint'); - } else if ($next['value'] === 'JOIN') { - //this is fine - $pointer++; - $next = $tokens[$pointer] ?? null; - } else if (C\contains_key(keyset['GROUP', 'ORDER'], $next['value'])) { - $pointer++; - $next = $tokens[$pointer] ?? null; - if ($next === null || $next['value'] !== 'BY') { - throw new SQLFakeParseException('Expected BY in index hint after GROUP or ORDER'); - } - - $pointer++; - $next = $tokens[$pointer] ?? 
null; - } else { - throw new SQLFakeParseException('Expected JOIN, ORDER BY, or GROUP BY after FOR in index hint'); - } - } - - if ($next === null || $next['type'] !== TokenType::PAREN) { - // USE hint is allowed to stop at "USE INDEX" which means "use no indexes" - if ($hint_type === 'USE') { - $pointer--; - return $pointer; - } - throw new SQLFakeParseException('Expected index expression after index hint'); - } - - $closing_paren_pointer = SQLParser::findMatchingParen($pointer, $tokens); - $arg_tokens = Vec\slice($tokens, $pointer + 1, $closing_paren_pointer - $pointer - 1); - if (!C\count($arg_tokens)) { - throw new SQLFakeParseException('Expected at least one argument to index hint'); - } - $count = 0; - foreach ($arg_tokens as $arg) { - $count++; - if ($count % 2 === 1) { - if ($arg['type'] !== TokenType::IDENTIFIER) { - throw new SQLFakeParseException('Expected identifier in index hint'); - } - } else if ($arg['value'] !== ',') { - throw new SQLFakeParseException('Expected , or ) after index hint'); - } - } - - $pointer = $closing_paren_pointer; - - // you can have multiple index hints... so if the next token starts another index hint we can go back into this while loop - $next_pointer = $pointer + 1; - $next = $tokens[$next_pointer] ?? null; - - } - - return $pointer; - } - - /* - * Skip over lock hints, but might as well still syntax validate them while doing so - * Examples of index hints: - * FROM table1 WHERE ... FOR UPDATE - * FROM table1 WHERE ... LOCK IN SHARE MODE - */ - public static function skipLockHints(int $pointer, token_list $tokens): int { - $next_pointer = $pointer + 1; - $next = $tokens[$next_pointer] ?? null; - - if ($next !== null && $next['type'] === TokenType::RESERVED) { - if ($next['value'] === 'FOR') { - // skip over FOR UDPATE while validating it - $next_pointer++; - $next = $tokens[$next_pointer] ?? 
null; - if ($next === null) { - throw new SQLFakeParseException('Expected keyword after FOR'); - } - // skip over FOR UPDATE - if ($next['value'] === 'UPDATE') { - return $pointer + 2; - } - - throw new SQLFakeParseException("Unexpected keyword {$next['value']} after FOR"); - } else if ($next['value'] === 'LOCK') { - // skip over LOCK IN SHARE MODE while validating it - $expected = vec['IN', 'SHARE', 'MODE']; - foreach ($expected as $index => $keyword) { - $next = $tokens[$next_pointer + $index + 1] ?? null; - if ($next === null || $next['value'] !== $keyword) { - throw new SQLFakeParseException('Unexpected keyword near LOCK'); - } - } - return $pointer + 4; - } - } - - return $pointer; - } - - const keyset CLAUSES = keyset[ - 'SELECT', - 'FROM', - 'WHERE', - 'GROUP', - 'HAVING', - 'LIMIT', - 'ORDER', - 'UPDATE', - 'SET', - 'DELETE', - 'UNION', - 'EXCEPT', - 'INTERSECT', - 'INSERT', - 'VALUES', - ]; - - // left PAREN is not in here because it triggers special logic, so it has its own type of PAREN - const keyset SEPARATORS = keyset[ - ')', - ',', - ';', - ]; - - const keyset OPERATORS = keyset[ - 'INTERVAL', - 'COLLATE', - '!', - '~', - '^', - '*', - '/', - 'DIV', - '%', - 'MOD', - '-', - '+', - '<<', - '>>', - '&', - '|', - '=', - '<=>', - '>=', - '>', - '<=', - '<', - '<>', - '!=', - 'IS', - 'LIKE', - 'REGEXP', - 'IN', - 'EXISTS', - 'BETWEEN', - 'CASE', - 'WHEN', - 'THEN', - 'ELSE', - 'END', - 'NOT', - 'AND', - '&&', - 'XOR', - 'OR', - '||', - ]; - - const keyset RESERVED_WORDS = keyset[ - 'ASC', - 'DESC', - 'AS', - 'WITH', - 'ON', - 'OFFSET', - 'BY', - 'INTO', - 'ALL', - 'DISTINCT', - 'DISTINCTROW', - 'SQL_CALC_FOUND_ROWS', - 'HIGH_PRIORITY', - 'SQL_SMALL_RESULT', - 'SQL_BIG_RESULT', - 'SQL_BUFFER_RESULT', - 'SQL_CACHE', - 'SQL_NO_CACHE', - 'JOIN', - 'INNER', - 'OUTER', - 'LEFT', - 'RIGHT', - 'STRAIGHT_JOIN', - 'NATURAL', - 'USING', - 'CROSS', - 'USE', - 'IGNORE', - 'FORCE', - 'PARTITION', - 'ROLLUP', - 'INDEX', - 'KEY', - 'FOR', - 'LOCK', - 'DUPLICATE', - 
'DELAYED', - 'LOW_PRIORITY', - 'HIGH_PRIORITY', - ]; + public static function parse(string $sql): Query { + // memoize hit rate on write queries is very low - so only memoize selects to avoid ballooning memory usage + if (Str\starts_with_ci($sql, 'SELECT')) { + return static::parseMemoized($sql); + } + return static::parseImpl($sql); + } + + private static function parseImpl(string $sql): Query { + $tokens = (new SQLLexer())->lex($sql); + $tokens = self::buildTokenListFromLexemes($tokens); + + $token = $tokens[0]; + // handle a query like (SELECT 1), just strip the surrounding parens + if ($token['type'] === TokenType::PAREN) { + $close = self::findMatchingParen(0, $tokens); + $tokens = Vec\slice($tokens, 1, $close - 1); + $token = $tokens[0]; + } + + if ($token['type'] !== TokenType::CLAUSE) { + throw new SQLFakeParseException("Unexpected {$token['value']}"); + } + + switch ($token['value']) { + case 'SELECT': + $select = new SelectParser(0, $tokens, $sql); + list($pointer, $query) = $select->parse(); + // we still have something left here after parsing the whole top level query? hopefully it's a multi-query keyword + if (C\contains_key($tokens, $pointer)) { + $next = $tokens[$pointer] ?? null; + while ($next !== null && C\contains_key(keyset['UNION', 'INTERSECT', 'EXCEPT'], $next['value'])) { + $type = $next['value']; + if ($next['value'] === 'UNION') { + $next_plus = $tokens[$pointer + 1]; + if ($next_plus['value'] === 'ALL') { + $type = 'UNION_ALL'; + $pointer++; + } + if ($next_plus['value'] === 'DISTINCT') { + $pointer++; + } + } + $pointer++; + $select = new SelectParser($pointer, $tokens, $sql); + list($pointer, $q) = $select->parse(); + $query->addMultiQuery(MultiOperand::assert($type), $q); + + $next = $tokens[$pointer] ?? 
null; + } + } + return $query; + case 'UPDATE': + $update = new UpdateParser($tokens, $sql); + return $update->parse(); + case 'DELETE': + $delete = new DeleteParser($tokens, $sql); + return $delete->parse(); + case 'INSERT': + $insert = new InsertParser($tokens, $sql); + return $insert->parse(); + default: + throw new SQLFakeParseException("Unexpected {$token['value']}"); + } + + throw new SQLFakeParseException('Parse error: unexpected end of input'); + } + + <<__Memoize>> + private static function parseMemoized(string $sql): Query { + return static::parseImpl($sql); + } + + /* + * Lexemes are just a vec of strings from the lexer + * This builds them into typed token shapes based on lists of reserved keywords and surrounding context + */ + private static function buildTokenListFromLexemes(vec $tokens): token_list { + $out = vec[]; + $count = C\count($tokens); + foreach ($tokens as $i => $token) { + // skip white space, but tack it onto the tokens in another field for when we need to assemble the expression list + if (Str\trim($token) === '') { + $k = C\last_key($out); + if ($k !== null) { + $previous = $out[$k]; + $previous['raw'] .= $token; + $out[$k] = $previous; + } + continue; + } + + if (Str\to_int($token) is nonnull) { + $out[] = shape( + 'type' => TokenType::NUMERIC_CONSTANT, + 'value' => $token, + 'raw' => $token, + ); + continue; + } else if (C\contains_key(keyset['\'', '"'], $token[0])) { + // chop off the quotes before storing the value + $raw = $token; + $token = Str\slice($token, 1, Str\length($token) - 2); + + // unescape everything except for % and _ (which only get unescaped during LIKE operations) + // there are a few other special sequnces we leave unescaped like \r, \n, \t, \b, \Z, \0 + // https://dev.mysql.com/doc/refman/5.7/en/string-literals.html + if (Str\contains($token, '\\')) { + $token_replaced = ''; + $escape_next = false; + for ($i = 0; $i < Str\length($token); $i++) { + if ($escape_next) { + switch ($token[$i]) { + case 'r': + 
$token_replaced .= "\r"; + break; + case 'n': + $token_replaced .= "\n"; + break; + case '0': + $token_replaced .= "\0"; + break; + case 't': + $token_replaced .= "\t"; + break; + case '\\': + $token_replaced .= '\\'; + break; + case '%': + case '_': + // these stay unescaped unless used in LIKE + $token_replaced .= "\\{$token[$i]}"; + break; + default: + $token_replaced .= $token[$i]; + } + $escape_next = false; + } else if ($token[$i] === '\\') { + $escape_next = true; + } else { + $token_replaced .= $token[$i]; + } + } + $token = $token_replaced; + } + $out[] = shape( + 'type' => TokenType::STRING_CONSTANT, + 'value' => $token, + 'raw' => $raw, + ); + continue; + } else if ($token[0] === '`') { + $raw = $token; + // Only chop off the ` if it's fully wrapping the identifier + if (Str\ends_with($token, '`')) { + $token = Str\strip_prefix($token, '`') |> Str\strip_suffix($$, '`'); + } else if (Str\ends_with($token, '`.')) { + // for `foo`., it becomes foo. + $token = Str\strip_prefix($token, '`') |> Str\strip_suffix($$, '`.') |> $$.'.'; + } + + // if we find an identifier and previous token ended with ., smush them together + $previous_key = C\last_key($out); + if ( + $previous_key is nonnull && + $out[$previous_key]['type'] === TokenType::IDENTIFIER && + Str\ends_with($out[$previous_key]['value'], '.') + ) { + $out[$previous_key]['value'] .= $token; + $out[$previous_key]['raw'] .= $raw; + continue; + } + + $out[] = shape( + 'type' => TokenType::IDENTIFIER, + 'value' => $token, + 'raw' => $raw, + ); + continue; + } else if ($token[0] === '(') { + $out[] = shape('type' => TokenType::PAREN, 'value' => $token, 'raw' => $token); + continue; + } else if ($token === '*') { + // the * character is special because it's sometimes an operator and most of the time it means "all columns" + $k = C\last_key($out); + if ($k === null) { + throw new SQLFakeParseException('Parse error: unexpected *'); + } + $previous = $out[$k]; + $out[$k] = $previous; + if ( + !C\contains_key( + 
keyset[ + TokenType::NUMERIC_CONSTANT, + TokenType::BOOLEAN_CONSTANT, + TokenType::STRING_CONSTANT, + TokenType::NULL_CONSTANT, + TokenType::IDENTIFIER, + ], + $previous['type'], + ) && + $previous['value'] !== ')' + ) { + $out[] = shape( + 'type' => TokenType::IDENTIFIER, + 'value' => $token, + 'raw' => $token, + ); + continue; + } else if ($previous['type'] === TokenType::IDENTIFIER && Str\ends_with($previous['value'], '.')) { + // previous ended like "foo.", we should keep "foo.*" together as one token + $previous['value'] .= $token; + $previous['raw'] .= $token; + $out[$k] = $previous; + continue; + } + } else if ($token === '-' || $token === '+') { + // these operands can be binary or unary operands + // for example "SELECT -5" or "SELECT 7 - 5" are both valid, in the first case it's a unary op + // we don't just combine it into the constant, because it's also valid for columns like SELECT -some_column FROM table + // similar to * we can identify this now based on context + $k = C\last_key($out); + if ($k === null) { + throw new SQLFakeParseException("Parse error: unexpected {$token}"); + } + $previous = $out[$k]; + if ( + !C\contains_key( + keyset[ + TokenType::NUMERIC_CONSTANT, + TokenType::BOOLEAN_CONSTANT, + TokenType::STRING_CONSTANT, + TokenType::NULL_CONSTANT, + TokenType::IDENTIFIER, + ], + $previous['type'], + + ) && + $previous['value'] !== ')' + ) { + if ($token === '-') { + $op = 'UNARY_MINUS'; + } else { + $op = 'UNARY_PLUS'; + } + $out[] = shape( + 'type' => TokenType::OPERATOR, + 'value' => $op, + 'raw' => $token, + ); + continue; + } + } + + $token_upper = Str\uppercase($token); + + if ($token_upper === 'NULL') { + $out[] = shape( + 'type' => TokenType::NULL_CONSTANT, + 'value' => $token, + 'raw' => $token, + ); + } else if ($token_upper === 'TRUE') { + $out[] = shape( + 'type' => TokenType::BOOLEAN_CONSTANT, + 'value' => '1', + 'raw' => $token, + ); + } else if ($token_upper === 'FALSE') { + $out[] = shape( + 'type' => 
TokenType::BOOLEAN_CONSTANT, + 'value' => '0', + 'raw' => $token, + ); + } else if (C\contains_key(self::CLAUSES, $token_upper)) { + $out[] = shape( + 'type' => TokenType::CLAUSE, + 'value' => $token_upper, + 'raw' => $token, + ); + } else if ( + C\contains_key(self::OPERATORS, $token_upper) && + !self::isFunctionVersionOfOperator($token_upper, $i, $count, $tokens) + ) { + $out[] = shape( + 'type' => TokenType::OPERATOR, + 'value' => $token_upper, + 'raw' => $token, + ); + } else if (C\contains_key(self::RESERVED_WORDS, $token_upper)) { + $out[] = shape( + 'type' => TokenType::RESERVED, + 'value' => $token_upper, + 'raw' => $token, + ); + } else if (C\contains_key(self::SEPARATORS, $token_upper)) { + $out[] = shape( + 'type' => TokenType::SEPARATOR, + 'value' => $token_upper, + 'raw' => $token, + ); + } else if ($i < $count - 1 && $tokens[$i + 1] === '(') { + $out[] = shape( + 'type' => TokenType::SQLFUNCTION, + 'value' => $token_upper, + 'raw' => $token, + ); + } else { + // if we find an identifier and previous token ended with ., smush them together + $previous_key = C\last_key($out); + if ( + $previous_key is nonnull && + $out[$previous_key]['type'] === TokenType::IDENTIFIER && + Str\ends_with($out[$previous_key]['value'], '.') + ) { + $out[$previous_key]['value'] .= $token; + continue; + } + $out[] = shape( + 'type' => TokenType::IDENTIFIER, + 'value' => $token, + 'raw' => $token, + ); + } + } + return $out; + } + + /** + * There seem to be a few operators that also exists as functions. MOD() for example. + * So we check if this particular find is an operator or a function. 
+ */ + private static function isFunctionVersionOfOperator( + string $token_upper, + int $i, + int $count, + vec $tokens, + ): bool { + return $token_upper === 'MOD' && $i < $count - 1 && $tokens[$i + 1] === '('; + } + + public static function findMatchingParen(int $pointer, token_list $tokens): int { + $paren_count = 0; + $remaining_tokens = Vec\drop($tokens, $pointer); + foreach ($remaining_tokens as $i => $token) { + if ($token['type'] === TokenType::PAREN) { + $paren_count++; + } else if ($token['type'] === TokenType::SEPARATOR && $token['value'] === ')') { + $paren_count--; + if ($paren_count === 0) { + return $pointer + $i; + } + } + } + + // if we get here, we didn't find the close + throw new SQLFakeParseException("Unclosed parentheses at index $pointer"); + } + + /* + * Skip over index hints, but might as well still syntax validate them while doing so + * Examples of index hints: + * FROM table1 USE INDEX (col1_index,col2_index) + * FROM t1 USE INDEX (i1) IGNORE INDEX FOR ORDER BY (i2) + */ + public static function skipIndexHints(int $pointer, token_list $tokens): int { + $next_pointer = $pointer + 1; + $next = $tokens[$next_pointer] ?? null; + while ( + $next !== null && + $next['type'] === TokenType::RESERVED && + C\contains_key(keyset['USE', 'IGNORE', 'FORCE'], $next['value']) + ) { + $pointer += 2; + $hint_type = $next['value']; + $next = $tokens[$pointer] ?? null; + if ($next === null || !C\contains_key(keyset['INDEX', 'KEY'], $next['value'])) { + throw new SQLFakeParseException('Expected INDEX or KEY in index hint'); + } + + $pointer++; + $next = $tokens[$pointer] ?? null; + if ($next === null) { + // USE hint is allowed to stop at "USE INDEX" which means "use no indexes" + if ($hint_type === 'USE') { + $pointer--; + return $pointer; + } + throw new SQLFakeParseException('Expected expected FOR or index list in index hint'); + } + + if ($next['value'] === 'FOR') { + $pointer++; + $next = $tokens[$pointer] ?? 
null; + if ($next === null) { + throw new SQLFakeParseException('Expected JOIN, ORDER BY, or GROUP BY after FOR in index hint'); + } else if ($next['value'] === 'JOIN') { + //this is fine + $pointer++; + $next = $tokens[$pointer] ?? null; + } else if (C\contains_key(keyset['GROUP', 'ORDER'], $next['value'])) { + $pointer++; + $next = $tokens[$pointer] ?? null; + if ($next === null || $next['value'] !== 'BY') { + throw new SQLFakeParseException('Expected BY in index hint after GROUP or ORDER'); + } + + $pointer++; + $next = $tokens[$pointer] ?? null; + } else { + throw new SQLFakeParseException('Expected JOIN, ORDER BY, or GROUP BY after FOR in index hint'); + } + } + + if ($next === null || $next['type'] !== TokenType::PAREN) { + // USE hint is allowed to stop at "USE INDEX" which means "use no indexes" + if ($hint_type === 'USE') { + $pointer--; + return $pointer; + } + throw new SQLFakeParseException('Expected index expression after index hint'); + } + + $closing_paren_pointer = SQLParser::findMatchingParen($pointer, $tokens); + $arg_tokens = Vec\slice($tokens, $pointer + 1, $closing_paren_pointer - $pointer - 1); + if (!C\count($arg_tokens)) { + throw new SQLFakeParseException('Expected at least one argument to index hint'); + } + $count = 0; + foreach ($arg_tokens as $arg) { + $count++; + if ($count % 2 === 1) { + if ($arg['type'] !== TokenType::IDENTIFIER) { + throw new SQLFakeParseException('Expected identifier in index hint'); + } + } else if ($arg['value'] !== ',') { + throw new SQLFakeParseException('Expected , or ) after index hint'); + } + } + + $pointer = $closing_paren_pointer; + + // you can have multiple index hints... so if the next token starts another index hint we can go back into this while loop + $next_pointer = $pointer + 1; + $next = $tokens[$next_pointer] ?? null; + + } + + return $pointer; + } + + /* + * Skip over lock hints, but might as well still syntax validate them while doing so + * Examples of index hints: + * FROM table1 WHERE ... 
FOR UPDATE + * FROM table1 WHERE ... LOCK IN SHARE MODE + */ + public static function skipLockHints(int $pointer, token_list $tokens): int { + $next_pointer = $pointer + 1; + $next = $tokens[$next_pointer] ?? null; + + if ($next !== null && $next['type'] === TokenType::RESERVED) { + if ($next['value'] === 'FOR') { + // skip over FOR UDPATE while validating it + $next_pointer++; + $next = $tokens[$next_pointer] ?? null; + if ($next === null) { + throw new SQLFakeParseException('Expected keyword after FOR'); + } + // skip over FOR UPDATE + if ($next['value'] === 'UPDATE') { + return $pointer + 2; + } + + throw new SQLFakeParseException("Unexpected keyword {$next['value']} after FOR"); + } else if ($next['value'] === 'LOCK') { + // skip over LOCK IN SHARE MODE while validating it + $expected = vec['IN', 'SHARE', 'MODE']; + foreach ($expected as $index => $keyword) { + $next = $tokens[$next_pointer + $index + 1] ?? null; + if ($next === null || $next['value'] !== $keyword) { + throw new SQLFakeParseException('Unexpected keyword near LOCK'); + } + } + return $pointer + 4; + } + } + + return $pointer; + } + + const keyset CLAUSES = keyset[ + 'SELECT', + 'FROM', + 'WHERE', + 'GROUP', + 'HAVING', + 'LIMIT', + 'ORDER', + 'UPDATE', + 'SET', + 'DELETE', + 'UNION', + 'EXCEPT', + 'INTERSECT', + 'INSERT', + 'VALUES', + ]; + + // left PAREN is not in here because it triggers special logic, so it has its own type of PAREN + const keyset SEPARATORS = keyset[ + ')', + ',', + ';', + ]; + + const keyset OPERATORS = keyset[ + 'INTERVAL', + 'COLLATE', + '!', + '~', + '^', + '*', + '/', + 'DIV', + '%', + 'MOD', + '-', + '+', + '<<', + '>>', + '&', + '|', + '=', + '<=>', + '>=', + '>', + '<=', + '<', + '<>', + '!=', + 'IS', + 'LIKE', + 'REGEXP', + 'IN', + 'EXISTS', + 'BETWEEN', + 'CASE', + 'WHEN', + 'THEN', + 'ELSE', + 'END', + 'NOT', + 'AND', + '&&', + 'XOR', + 'OR', + '||', + ]; + + const keyset RESERVED_WORDS = keyset[ + 'ASC', + 'DESC', + 'AS', + 'WITH', + 'ON', + 'OFFSET', + 'BY', + 
'INTO', + 'ALL', + 'DISTINCT', + 'DISTINCTROW', + 'SQL_CALC_FOUND_ROWS', + 'HIGH_PRIORITY', + 'SQL_SMALL_RESULT', + 'SQL_BIG_RESULT', + 'SQL_BUFFER_RESULT', + 'SQL_CACHE', + 'SQL_NO_CACHE', + 'JOIN', + 'INNER', + 'OUTER', + 'LEFT', + 'RIGHT', + 'STRAIGHT_JOIN', + 'NATURAL', + 'USING', + 'CROSS', + 'USE', + 'IGNORE', + 'FORCE', + 'PARTITION', + 'ROLLUP', + 'INDEX', + 'KEY', + 'FOR', + 'LOCK', + 'DUPLICATE', + 'DELAYED', + 'LOW_PRIORITY', + 'HIGH_PRIORITY', + ]; } diff --git a/src/Parser/SelectParser.php b/src/Parser/SelectParser.php index 6ff3efe..46629fc 100644 --- a/src/Parser/SelectParser.php +++ b/src/Parser/SelectParser.php @@ -6,231 +6,231 @@ final class SelectParser { - const dict CLAUSE_ORDER = dict[ - 'SELECT' => 1, - 'FROM' => 2, - 'WHERE' => 3, - 'GROUP' => 4, - 'HAVING' => 5, - 'ORDER' => 6, - 'LIMIT' => 7, - ]; - - private string $currentClause = 'SELECT'; - public function __construct(private int $pointer, private token_list $tokens, private string $sql) {} - - public function parse(): (int, SelectQuery) { - // if we got here, the first token had better be a SELECT - $token = $this->tokens[$this->pointer] ?? null; - $incr = 0; - while ($token is nonnull && $token['type'] === TokenType::PAREN) { - $close = SQLParser::findMatchingParen($this->pointer, $this->tokens); - $this->tokens = Vec\slice($this->tokens, $this->pointer + 1, $close - $this->pointer - 1); - $this->pointer = 0; - $incr++; - $token = $this->tokens[$this->pointer + $incr] ?? 
null; - } - - // for every paren we took off, we took of two tokens - $incr *= 2; - - if ($this->tokens[$this->pointer]['value'] !== 'SELECT') { - throw new SQLFakeParseException('Parser error: expected SELECT'); - } - - $query = new SelectQuery($this->sql); - $this->pointer++; - $count = C\count($this->tokens); - - while ($this->pointer < $count) { - $token = $this->tokens[$this->pointer]; - - switch ($token['type']) { - case TokenType::NUMERIC_CONSTANT: - case TokenType::BOOLEAN_CONSTANT: - case TokenType::NULL_CONSTANT: - case TokenType::STRING_CONSTANT: - case TokenType::OPERATOR: - case TokenType::SQLFUNCTION: - case TokenType::IDENTIFIER: - case TokenType::PAREN: - // we should only see these things when we're in the SELECT clause - // all other clauses should parse their own tokens - // also check that there has been a delimiter since the last expression if we're adding a new one now - if ($this->currentClause !== 'SELECT') { - throw new SQLFakeParseException("Unexpected {$token['value']}"); - } - if ($query->needsSeparator) { - // we just had an expression and no comma yet. 
if this is a string or identifier, it must be an alias like "SELECT 1 foo" - if ( - C\contains_key(keyset[TokenType::IDENTIFIER, TokenType::STRING_CONSTANT], $token['type']) && - !$query->mostRecentHasAlias - ) { - $query->aliasRecentExpression($token['value']); - break; - } else { - // if the new token isn't an identifier, or the most recent expression had an alias, this is bogus - throw new SQLFakeParseException("Expected comma between expressions near {$token['value']}"); - } - } - $expression_parser = new ExpressionParser($this->tokens, $this->pointer - 1); - $start = $this->pointer; - list($this->pointer, $expression) = $expression_parser->buildWithPointer(); - - // we should "name" the column based on the entire expression (it may get overwritten by an alias) - // this is actually important, because otherwise if two expressions exist in the select they might not both be in the results if we don't give them unique names - // since the result rows are keyed by strings. only do this for non-scalar expressions - if (!$expression is ColumnExpression && !$expression is ConstantExpression) { - $name = ''; - $slice = Vec\slice($this->tokens, $start, $this->pointer - $start + 1); - foreach ($slice as $t) { - $name .= $t['raw']; - } - $expression->name = Str\trim($name); - } - - $query->addSelectExpression($expression); - break; - case TokenType::SEPARATOR: - if ($token['value'] === ',') { - if (!$query->needsSeparator) { - throw new SQLFakeParseException('Unexpected ,'); - } - $query->needsSeparator = false; - } else if ($token['value'] === ';') { - // this should be the final token. if it's not, throw. 
otherwise, return - if ($this->pointer !== $count - 1) { - throw new SQLFakeParseException('Unexpected tokens after semicolon'); - } - return tuple($this->pointer + $incr, $query); - } else { - throw new SQLFakeParseException("Unexpected {$token['value']}"); - } - break; - case TokenType::CLAUSE: - // make sure clauses are in order - if ( - C\contains_key(self::CLAUSE_ORDER, $token['value']) && - self::CLAUSE_ORDER[$this->currentClause] >= self::CLAUSE_ORDER[$token['value']] - ) { - throw new SQLFakeParseException("Unexpected {$token['value']}"); - } - $this->currentClause = $token['value']; - switch ($token['value']) { - case 'FROM': - $from = new FromParser($this->pointer, $this->tokens); - list($this->pointer, $fromClause) = $from->parse(); - $query->fromClause = $fromClause; - break; - case 'WHERE': - $expression_parser = new ExpressionParser($this->tokens, $this->pointer); - list($this->pointer, $expression) = $expression_parser->buildWithPointer(); - $query->whereClause = $expression; - break; - case 'GROUP': - $this->pointer++; - $next = $this->tokens[$this->pointer] ?? null; - $expressions = vec[]; - if ($next === null || $next['value'] !== 'BY') { - throw new SQLFakeParseException('Expected BY after GROUP'); - } - - while (true) { - $expression_parser = new ExpressionParser($this->tokens, $this->pointer); - $expression_parser->setSelectExpressions($query->selectExpressions); - list($this->pointer, $expression) = $expression_parser->buildWithPointer(); - - // group by and order by support POSITIONAL operators such as "GROUP BY 1". And constants aren't supported. - // so if a constant comes back in the list, it has to be a positional operator - // and the columns are 1-indexed in MySQL terms so we subtract one from the $position arg to find the right expression - if ($expression is ConstantExpression) { - $position = (int)($expression->value) - 1; - - $expression = $query->selectExpressions[$position] ?? 
null; - if ($expression === null) { - throw new SQLFakeParseException("Invalid positional reference $position in GROUP BY"); - } - } - - $expressions[] = $expression; - $next = $this->tokens[$this->pointer + 1] ?? null; - // skip over commas and continue the processing, but if it's any other token break out of the loop - if ($next === null || $next['value'] !== ',') { - break; - } - $this->pointer++; - } - - $query->groupBy = $expressions; - break; - case 'ORDER': - $p = new OrderByParser($this->pointer, $this->tokens, $query->selectExpressions); - list($this->pointer, $query->orderBy) = $p->parse(); - break; - case 'HAVING': - // same as where, except we add select expressions here - $expression_parser = new ExpressionParser($this->tokens, $this->pointer); - $expression_parser->setSelectExpressions($query->selectExpressions); - list($this->pointer, $expression) = $expression_parser->buildWithPointer(); - $query->havingClause = $expression; - break; - case 'LIMIT': - $p = new LimitParser($this->pointer, $this->tokens); - list($this->pointer, $query->limitClause) = $p->parse(); - break; - case 'UNION': - case 'EXCEPT': - case 'INTERSECT': - // return control back to parent, so that if we are at top level we can add this and otherwise not - return tuple($this->pointer + $incr, $query); - break; - default: - throw new SQLFakeParseException("Unexpected {$token['value']}"); - } - // after adding a clause like FROM or WHERE, skip over any locking hints - $this->pointer = SQLParser::skipLockHints($this->pointer, $this->tokens); - break; - case TokenType::RESERVED: - switch ($token['value']) { - case 'AS': - // seek forward for identifier, then add alias to most recent expression - $this->pointer++; - $next = $this->tokens[$this->pointer] ?? 
null; - if ( - $next === null || - !C\contains_key(keyset[TokenType::IDENTIFIER, TokenType::STRING_CONSTANT], $next['type']) - ) { - throw new SQLFakeParseException('Expected alias name after AS'); - } - $query->aliasRecentExpression($next['value']); - break; - case 'DISTINCT': - case 'DISTINCTROW': - case 'ALL': - case 'HIGH_PRIORITY': - case 'SQL_CALC_FOUND_ROWS': - case 'HIGH_PRIORITY': - case 'SQL_SMALL_RESULT': - case 'SQL_BIG_RESULT': - case 'SQL_BUFFER_RESULT': - case 'SQL_CACHE': - case 'SQL_NO_CACHE': - // DISTINCTROW is an alias for DISTINCT - if ($token['value'] === 'DISTINCTROW') { - $token['value'] = 'DISTINCT'; - } - $query->addOption($token['value']); - break; - default: - throw new SQLFakeParseException("Unexpected {$token['value']}"); - } - break; - } - - $this->pointer++; - } - - // check if query is well formed here? well basically it just has to have at least one expression in the SELECT clause - return tuple($this->pointer + $incr, $query); - } + const dict CLAUSE_ORDER = dict[ + 'SELECT' => 1, + 'FROM' => 2, + 'WHERE' => 3, + 'GROUP' => 4, + 'HAVING' => 5, + 'ORDER' => 6, + 'LIMIT' => 7, + ]; + + private string $currentClause = 'SELECT'; + public function __construct(private int $pointer, private token_list $tokens, private string $sql) {} + + public function parse(): (int, SelectQuery) { + // if we got here, the first token had better be a SELECT + $token = $this->tokens[$this->pointer] ?? null; + $incr = 0; + while ($token is nonnull && $token['type'] === TokenType::PAREN) { + $close = SQLParser::findMatchingParen($this->pointer, $this->tokens); + $this->tokens = Vec\slice($this->tokens, $this->pointer + 1, $close - $this->pointer - 1); + $this->pointer = 0; + $incr++; + $token = $this->tokens[$this->pointer + $incr] ?? 
null; + } + + // for every paren we took off, we took of two tokens + $incr *= 2; + + if ($this->tokens[$this->pointer]['value'] !== 'SELECT') { + throw new SQLFakeParseException('Parser error: expected SELECT'); + } + + $query = new SelectQuery($this->sql); + $this->pointer++; + $count = C\count($this->tokens); + + while ($this->pointer < $count) { + $token = $this->tokens[$this->pointer]; + + switch ($token['type']) { + case TokenType::NUMERIC_CONSTANT: + case TokenType::BOOLEAN_CONSTANT: + case TokenType::NULL_CONSTANT: + case TokenType::STRING_CONSTANT: + case TokenType::OPERATOR: + case TokenType::SQLFUNCTION: + case TokenType::IDENTIFIER: + case TokenType::PAREN: + // we should only see these things when we're in the SELECT clause + // all other clauses should parse their own tokens + // also check that there has been a delimiter since the last expression if we're adding a new one now + if ($this->currentClause !== 'SELECT') { + throw new SQLFakeParseException("Unexpected {$token['value']}"); + } + if ($query->needsSeparator) { + // we just had an expression and no comma yet. 
if this is a string or identifier, it must be an alias like "SELECT 1 foo" + if ( + C\contains_key(keyset[TokenType::IDENTIFIER, TokenType::STRING_CONSTANT], $token['type']) && + !$query->mostRecentHasAlias + ) { + $query->aliasRecentExpression($token['value']); + break; + } else { + // if the new token isn't an identifier, or the most recent expression had an alias, this is bogus + throw new SQLFakeParseException("Expected comma between expressions near {$token['value']}"); + } + } + $expression_parser = new ExpressionParser($this->tokens, $this->pointer - 1); + $start = $this->pointer; + list($this->pointer, $expression) = $expression_parser->buildWithPointer(); + + // we should "name" the column based on the entire expression (it may get overwritten by an alias) + // this is actually important, because otherwise if two expressions exist in the select they might not both be in the results if we don't give them unique names + // since the result rows are keyed by strings. only do this for non-scalar expressions + if (!$expression is ColumnExpression && !$expression is ConstantExpression) { + $name = ''; + $slice = Vec\slice($this->tokens, $start, $this->pointer - $start + 1); + foreach ($slice as $t) { + $name .= $t['raw']; + } + $expression->name = Str\trim($name); + } + + $query->addSelectExpression($expression); + break; + case TokenType::SEPARATOR: + if ($token['value'] === ',') { + if (!$query->needsSeparator) { + throw new SQLFakeParseException('Unexpected ,'); + } + $query->needsSeparator = false; + } else if ($token['value'] === ';') { + // this should be the final token. if it's not, throw. 
otherwise, return + if ($this->pointer !== $count - 1) { + throw new SQLFakeParseException('Unexpected tokens after semicolon'); + } + return tuple($this->pointer + $incr, $query); + } else { + throw new SQLFakeParseException("Unexpected {$token['value']}"); + } + break; + case TokenType::CLAUSE: + // make sure clauses are in order + if ( + C\contains_key(self::CLAUSE_ORDER, $token['value']) && + self::CLAUSE_ORDER[$this->currentClause] >= self::CLAUSE_ORDER[$token['value']] + ) { + throw new SQLFakeParseException("Unexpected {$token['value']}"); + } + $this->currentClause = $token['value']; + switch ($token['value']) { + case 'FROM': + $from = new FromParser($this->pointer, $this->tokens); + list($this->pointer, $fromClause) = $from->parse(); + $query->fromClause = $fromClause; + break; + case 'WHERE': + $expression_parser = new ExpressionParser($this->tokens, $this->pointer); + list($this->pointer, $expression) = $expression_parser->buildWithPointer(); + $query->whereClause = $expression; + break; + case 'GROUP': + $this->pointer++; + $next = $this->tokens[$this->pointer] ?? null; + $expressions = vec[]; + if ($next === null || $next['value'] !== 'BY') { + throw new SQLFakeParseException('Expected BY after GROUP'); + } + + while (true) { + $expression_parser = new ExpressionParser($this->tokens, $this->pointer); + $expression_parser->setSelectExpressions($query->selectExpressions); + list($this->pointer, $expression) = $expression_parser->buildWithPointer(); + + // group by and order by support POSITIONAL operators such as "GROUP BY 1". And constants aren't supported. + // so if a constant comes back in the list, it has to be a positional operator + // and the columns are 1-indexed in MySQL terms so we subtract one from the $position arg to find the right expression + if ($expression is ConstantExpression) { + $position = (int)($expression->value) - 1; + + $expression = $query->selectExpressions[$position] ?? 
null; + if ($expression === null) { + throw new SQLFakeParseException("Invalid positional reference $position in GROUP BY"); + } + } + + $expressions[] = $expression; + $next = $this->tokens[$this->pointer + 1] ?? null; + // skip over commas and continue the processing, but if it's any other token break out of the loop + if ($next === null || $next['value'] !== ',') { + break; + } + $this->pointer++; + } + + $query->groupBy = $expressions; + break; + case 'ORDER': + $p = new OrderByParser($this->pointer, $this->tokens, $query->selectExpressions); + list($this->pointer, $query->orderBy) = $p->parse(); + break; + case 'HAVING': + // same as where, except we add select expressions here + $expression_parser = new ExpressionParser($this->tokens, $this->pointer); + $expression_parser->setSelectExpressions($query->selectExpressions); + list($this->pointer, $expression) = $expression_parser->buildWithPointer(); + $query->havingClause = $expression; + break; + case 'LIMIT': + $p = new LimitParser($this->pointer, $this->tokens); + list($this->pointer, $query->limitClause) = $p->parse(); + break; + case 'UNION': + case 'EXCEPT': + case 'INTERSECT': + // return control back to parent, so that if we are at top level we can add this and otherwise not + return tuple($this->pointer + $incr, $query); + break; + default: + throw new SQLFakeParseException("Unexpected {$token['value']}"); + } + // after adding a clause like FROM or WHERE, skip over any locking hints + $this->pointer = SQLParser::skipLockHints($this->pointer, $this->tokens); + break; + case TokenType::RESERVED: + switch ($token['value']) { + case 'AS': + // seek forward for identifier, then add alias to most recent expression + $this->pointer++; + $next = $this->tokens[$this->pointer] ?? 
null; + if ( + $next === null || + !C\contains_key(keyset[TokenType::IDENTIFIER, TokenType::STRING_CONSTANT], $next['type']) + ) { + throw new SQLFakeParseException('Expected alias name after AS'); + } + $query->aliasRecentExpression($next['value']); + break; + case 'DISTINCT': + case 'DISTINCTROW': + case 'ALL': + case 'HIGH_PRIORITY': + case 'SQL_CALC_FOUND_ROWS': + case 'HIGH_PRIORITY': + case 'SQL_SMALL_RESULT': + case 'SQL_BIG_RESULT': + case 'SQL_BUFFER_RESULT': + case 'SQL_CACHE': + case 'SQL_NO_CACHE': + // DISTINCTROW is an alias for DISTINCT + if ($token['value'] === 'DISTINCTROW') { + $token['value'] = 'DISTINCT'; + } + $query->addOption($token['value']); + break; + default: + throw new SQLFakeParseException("Unexpected {$token['value']}"); + } + break; + } + + $this->pointer++; + } + + // check if query is well formed here? well basically it just has to have at least one expression in the SELECT clause + return tuple($this->pointer + $incr, $query); + } } diff --git a/src/Parser/SetParser.php b/src/Parser/SetParser.php index da25d8f..1dbce05 100644 --- a/src/Parser/SetParser.php +++ b/src/Parser/SetParser.php @@ -7,80 +7,80 @@ // process the SET clause of an UPDATE, or the UPDATE portion of INSERT .. 
ON DUPLICATE KEY UPDATE final class SetParser { - public function __construct(private int $pointer, private token_list $tokens) {} + public function __construct(private int $pointer, private token_list $tokens) {} - public function parse(bool $skip_set = false): (int, vec) { + public function parse(bool $skip_set = false): (int, vec) { - // if we got here, the first token had better be a SET - if (!$skip_set && $this->tokens[$this->pointer]['value'] !== 'SET') { - throw new SQLFakeParseException('Parser error: expected SET'); - } - $expressions = vec[]; - $this->pointer++; - $count = C\count($this->tokens); + // if we got here, the first token had better be a SET + if (!$skip_set && $this->tokens[$this->pointer]['value'] !== 'SET') { + throw new SQLFakeParseException('Parser error: expected SET'); + } + $expressions = vec[]; + $this->pointer++; + $count = C\count($this->tokens); - $needs_comma = false; - $end_of_set = false; - while ($this->pointer < $count) { - $token = $this->tokens[$this->pointer]; + $needs_comma = false; + $end_of_set = false; + while ($this->pointer < $count) { + $token = $this->tokens[$this->pointer]; - switch ($token['type']) { - case TokenType::NUMERIC_CONSTANT: - case TokenType::BOOLEAN_CONSTANT: - case TokenType::STRING_CONSTANT: - case TokenType::NULL_CONSTANT: - case TokenType::OPERATOR: - case TokenType::SQLFUNCTION: - case TokenType::IDENTIFIER: - case TokenType::PAREN: - if ($needs_comma) { - throw new SQLFakeParseException('Expected , between expressions in SET clause'); - } - $expression_parser = new ExpressionParser($this->tokens, $this->pointer - 1); - $this->pointer; - list($this->pointer, $expression) = $expression_parser->buildWithPointer(); + switch ($token['type']) { + case TokenType::NUMERIC_CONSTANT: + case TokenType::BOOLEAN_CONSTANT: + case TokenType::STRING_CONSTANT: + case TokenType::NULL_CONSTANT: + case TokenType::OPERATOR: + case TokenType::SQLFUNCTION: + case TokenType::IDENTIFIER: + case TokenType::PAREN: + if 
($needs_comma) { + throw new SQLFakeParseException('Expected , between expressions in SET clause'); + } + $expression_parser = new ExpressionParser($this->tokens, $this->pointer - 1); + $this->pointer; + list($this->pointer, $expression) = $expression_parser->buildWithPointer(); - // the only valid kind of expression in a SET is "foo = bar" - if (!$expression is BinaryOperatorExpression || $expression->operator !== Operator::EQUALS) { - throw new SQLFakeParseException('Failed parsing SET clause: unexpected expression'); - } + // the only valid kind of expression in a SET is "foo = bar" + if (!$expression is BinaryOperatorExpression || $expression->operator !== Operator::EQUALS) { + throw new SQLFakeParseException('Failed parsing SET clause: unexpected expression'); + } - if (!$expression->left is ColumnExpression) { - throw new SQLFakeParseException('Left side of SET clause must be a column reference'); - } + if (!$expression->left is ColumnExpression) { + throw new SQLFakeParseException('Left side of SET clause must be a column reference'); + } - $expressions[] = $expression; - $needs_comma = true; - break; - case TokenType::SEPARATOR: - if ($token['value'] === ',') { - if (!$needs_comma) { - throw new SQLFakeParseException('Unexpected ,'); - } - $needs_comma = false; - } else { - throw new SQLFakeParseException("Unexpected {$token['value']}"); - } - break; - case TokenType::CLAUSE: - // return once we get to the next clause - $end_of_set = true; - break; - default: - throw new SQLFakeParseException("Unexpected {$token['value']} in SET"); - } + $expressions[] = $expression; + $needs_comma = true; + break; + case TokenType::SEPARATOR: + if ($token['value'] === ',') { + if (!$needs_comma) { + throw new SQLFakeParseException('Unexpected ,'); + } + $needs_comma = false; + } else { + throw new SQLFakeParseException("Unexpected {$token['value']}"); + } + break; + case TokenType::CLAUSE: + // return once we get to the next clause + $end_of_set = true; + break; + default: 
+ throw new SQLFakeParseException("Unexpected {$token['value']} in SET"); + } - if ($end_of_set) { - break; - } + if ($end_of_set) { + break; + } - $this->pointer++; - } + $this->pointer++; + } - if (!C\count($expressions)) { - throw new SQLFakeParseException('Empty SET clause'); - } + if (!C\count($expressions)) { + throw new SQLFakeParseException('Empty SET clause'); + } - return tuple($this->pointer - 1, $expressions); - } + return tuple($this->pointer - 1, $expressions); + } } diff --git a/src/Parser/UpdateParser.php b/src/Parser/UpdateParser.php index 8d00bca..5ad0221 100644 --- a/src/Parser/UpdateParser.php +++ b/src/Parser/UpdateParser.php @@ -6,98 +6,98 @@ final class UpdateParser { - const dict CLAUSE_ORDER = dict[ - 'UPDATE' => 1, - 'SET' => 2, - 'WHERE' => 3, - 'ORDER' => 4, - 'LIMIT' => 5, - ]; - - private string $current_clause = 'UPDATE'; - private int $pointer = 0; - - public function __construct(private token_list $tokens, private string $sql) {} - - public function parse(): UpdateQuery { - - // if we got here, the first token had better be a UPDATE - if ($this->tokens[$this->pointer]['value'] !== 'UPDATE') { - throw new SQLFakeParseException('Parser error: expected UPDATE'); - } - $this->pointer++; - - // IGNORE can come next and indicates duplicate keys should be ignored - $ignore_dupes = false; - if ($this->tokens[$this->pointer]['value'] === 'IGNORE') { - $ignore_dupes = true; - $this->pointer++; - } - - $count = C\count($this->tokens); - - // next token has to be a table name - $token = $this->tokens[$this->pointer]; - if ($token === null || $token['type'] !== TokenType::IDENTIFIER) { - throw new SQLFakeParseException('Expected table name after UPDATE'); - } - - $this->pointer = SQLParser::skipIndexHints($this->pointer, $this->tokens); - - $table = shape('name' => $token['value'], 'join_type' => JoinType::JOIN); - - $query = new UpdateQuery($table, $this->sql, $ignore_dupes); - - $this->pointer++; - - while ($this->pointer < $count) { - $token 
= $this->tokens[$this->pointer]; - - switch ($token['type']) { - case TokenType::CLAUSE: - // make sure clauses are in order - if ( - C\contains_key(self::CLAUSE_ORDER, $token['value']) && - self::CLAUSE_ORDER[$this->current_clause] >= self::CLAUSE_ORDER[$token['value']] - ) { - throw new SQLFakeParseException("Unexpected clause {$token['value']}"); - } - $this->current_clause = $token['value']; - switch ($token['value']) { - case 'WHERE': - $expression_parser = new ExpressionParser($this->tokens, $this->pointer); - list($this->pointer, $expression) = $expression_parser->buildWithPointer(); - $query->whereClause = $expression; - break; - case 'ORDER': - $p = new OrderByParser($this->pointer, $this->tokens); - list($this->pointer, $query->orderBy) = $p->parse(); - break; - case 'LIMIT': - $p = new LimitParser($this->pointer, $this->tokens); - list($this->pointer, $query->limitClause) = $p->parse(); - break; - case 'SET': - $p = new SetParser($this->pointer, $this->tokens); - list($this->pointer, $query->setClause) = $p->parse(); - break; - default: - throw new SQLFakeParseException("Unexpected clause {$token['value']}"); - } - break; - case TokenType::SEPARATOR: - // a semicolon to end the query is valid, but nothing else is in this context - if ($token['value'] !== ';') { - throw new SQLFakeParseException("Unexpected {$token['value']}"); - } - break; - default: - throw new SQLFakeParseException("Unexpected token {$token['value']}"); - } - - $this->pointer++; - } - - return $query; - } + const dict CLAUSE_ORDER = dict[ + 'UPDATE' => 1, + 'SET' => 2, + 'WHERE' => 3, + 'ORDER' => 4, + 'LIMIT' => 5, + ]; + + private string $current_clause = 'UPDATE'; + private int $pointer = 0; + + public function __construct(private token_list $tokens, private string $sql) {} + + public function parse(): UpdateQuery { + + // if we got here, the first token had better be a UPDATE + if ($this->tokens[$this->pointer]['value'] !== 'UPDATE') { + throw new SQLFakeParseException('Parser 
error: expected UPDATE'); + } + $this->pointer++; + + // IGNORE can come next and indicates duplicate keys should be ignored + $ignore_dupes = false; + if ($this->tokens[$this->pointer]['value'] === 'IGNORE') { + $ignore_dupes = true; + $this->pointer++; + } + + $count = C\count($this->tokens); + + // next token has to be a table name + $token = $this->tokens[$this->pointer]; + if ($token === null || $token['type'] !== TokenType::IDENTIFIER) { + throw new SQLFakeParseException('Expected table name after UPDATE'); + } + + $this->pointer = SQLParser::skipIndexHints($this->pointer, $this->tokens); + + $table = shape('name' => $token['value'], 'join_type' => JoinType::JOIN); + + $query = new UpdateQuery($table, $this->sql, $ignore_dupes); + + $this->pointer++; + + while ($this->pointer < $count) { + $token = $this->tokens[$this->pointer]; + + switch ($token['type']) { + case TokenType::CLAUSE: + // make sure clauses are in order + if ( + C\contains_key(self::CLAUSE_ORDER, $token['value']) && + self::CLAUSE_ORDER[$this->current_clause] >= self::CLAUSE_ORDER[$token['value']] + ) { + throw new SQLFakeParseException("Unexpected clause {$token['value']}"); + } + $this->current_clause = $token['value']; + switch ($token['value']) { + case 'WHERE': + $expression_parser = new ExpressionParser($this->tokens, $this->pointer); + list($this->pointer, $expression) = $expression_parser->buildWithPointer(); + $query->whereClause = $expression; + break; + case 'ORDER': + $p = new OrderByParser($this->pointer, $this->tokens); + list($this->pointer, $query->orderBy) = $p->parse(); + break; + case 'LIMIT': + $p = new LimitParser($this->pointer, $this->tokens); + list($this->pointer, $query->limitClause) = $p->parse(); + break; + case 'SET': + $p = new SetParser($this->pointer, $this->tokens); + list($this->pointer, $query->setClause) = $p->parse(); + break; + default: + throw new SQLFakeParseException("Unexpected clause {$token['value']}"); + } + break; + case TokenType::SEPARATOR: + // 
a semicolon to end the query is valid, but nothing else is in this context + if ($token['value'] !== ';') { + throw new SQLFakeParseException("Unexpected {$token['value']}"); + } + break; + default: + throw new SQLFakeParseException("Unexpected token {$token['value']}"); + } + + $this->pointer++; + } + + return $query; + } } diff --git a/src/Query/DeleteQuery.php b/src/Query/DeleteQuery.php index 0e000eb..5f8b648 100644 --- a/src/Query/DeleteQuery.php +++ b/src/Query/DeleteQuery.php @@ -5,43 +5,43 @@ use namespace HH\Lib\{C, Keyset, Vec}; final class DeleteQuery extends Query { - public ?from_table $fromClause = null; - - public function __construct(public string $sql) {} - - public function execute(AsyncMysqlConnection $conn): int { - $this->fromClause as nonnull; - list($database, $table_name) = Query::parseTableName($conn, $this->fromClause['name']); - $data = $conn->getServer()->getTable($database, $table_name) ?? vec[]; - Metrics::trackQuery(QueryType::DELETE, $conn->getServer()->name, $table_name, $this->sql); - - return $this->applyWhere($conn, $data) - |> $this->applyOrderBy($conn, $$) - |> $this->applyLimit($$) - |> $this->applyDelete($conn, $database, $table_name, $$, $data); - } - - /** - * Delete rows after all filtering clauses, and return the number of rows deleted - */ - protected function applyDelete( - AsyncMysqlConnection $conn, - string $database, - string $table_name, - dataset $filtered_rows, - dataset $original_table, - ): int { - - // if this isn't a dict keyed by the original ids in the row, it could delete the wrong rows - $filtered_rows as dict<_, _>; - - $rows_to_delete = Keyset\keys($filtered_rows); - $remaining_rows = - Vec\filter_with_key($original_table, ($row_num, $_) ==> !C\contains_key($rows_to_delete, $row_num)); - $rows_affected = C\count($original_table) - C\count($remaining_rows); - - // write it back to the database - $conn->getServer()->saveTable($database, $table_name, $remaining_rows); - return $rows_affected; - } + public 
?from_table $fromClause = null; + + public function __construct(public string $sql) {} + + public function execute(AsyncMysqlConnection $conn): int { + $this->fromClause as nonnull; + list($database, $table_name) = Query::parseTableName($conn, $this->fromClause['name']); + $data = $conn->getServer()->getTable($database, $table_name) ?? vec[]; + Metrics::trackQuery(QueryType::DELETE, $conn->getServer()->name, $table_name, $this->sql); + + return $this->applyWhere($conn, $data) + |> $this->applyOrderBy($conn, $$) + |> $this->applyLimit($$) + |> $this->applyDelete($conn, $database, $table_name, $$, $data); + } + + /** + * Delete rows after all filtering clauses, and return the number of rows deleted + */ + protected function applyDelete( + AsyncMysqlConnection $conn, + string $database, + string $table_name, + dataset $filtered_rows, + dataset $original_table, + ): int { + + // if this isn't a dict keyed by the original ids in the row, it could delete the wrong rows + $filtered_rows as dict<_, _>; + + $rows_to_delete = Keyset\keys($filtered_rows); + $remaining_rows = + Vec\filter_with_key($original_table, ($row_num, $_) ==> !C\contains_key($rows_to_delete, $row_num)); + $rows_affected = C\count($original_table) - C\count($remaining_rows); + + // write it back to the database + $conn->getServer()->saveTable($database, $table_name, $remaining_rows); + return $rows_affected; + } } diff --git a/src/Query/FromClause.php b/src/Query/FromClause.php index abb0bdc..aa98c78 100644 --- a/src/Query/FromClause.php +++ b/src/Query/FromClause.php @@ -12,122 +12,122 @@ */ final class FromClause { - public vec $tables = vec[]; - public bool $mostRecentHasAlias = false; - - public function addTable(from_table $table): void { - $this->tables[] = $table; - $this->mostRecentHasAlias = false; - } - - public function aliasRecentExpression(string $name): void { - $k = C\last_key($this->tables); - if ($k === null || $this->mostRecentHasAlias) { - throw new SQLFakeParseException('Unexpected 
AS'); - } - $this->tables[$k]['alias'] = $name; - $this->mostRecentHasAlias = true; - } - - /** - * The FROM clause of the query gets processed first, retrieving data from tables, executing subqueries, and handling joins - * This is also where we build up the $columns list which is commonly used throughout the entire library to map column references to indexes in this dataset - * @reviewer, we don't build up the $columns, since the variable is unused... - */ - public function process(AsyncMysqlConnection $conn, string $sql): dataset { - - $data = vec[]; - $is_first_table = true; - - foreach ($this->tables as $table) { - $schema = null; - if (Shapes::keyExists($table, 'subquery')) { - $res = $table['subquery']->evaluate(dict[], $conn); - $name = $table['name']; - } else { - $table_name = $table['name']; - - list($database, $table_name) = Query::parseTableName($conn, $table_name); - - // TODO if different database, should $name have that in it as well for other things like column references? probably, right? - $name = $table['alias'] ?? 
$table_name; - $schema = QueryContext::getSchema($database, $table_name); - if ($schema === null && QueryContext::$strictSchemaMode) { - throw new SQLFakeRuntimeException("Table $table_name not found in schema and strict mode is enabled"); - } - - $res = $conn->getServer()->getTable($database, $table_name); - - if ($res === null) { - $res = vec[]; - } - } - - invariant($res is KeyedContainer<_, _>, 'evaluated result of SubqueryExpression must be dataset'); - - $new_dataset = vec[]; - if ($schema is nonnull && QueryContext::$strictSchemaMode) { - foreach ($res as $row) { - $row as dict<_, _>; - $m = dict[]; - foreach ($row as $field => $val) { - $m["{$name}.{$field}"] = $val; - } - $new_dataset[] = $m; - } - } else if ($schema is nonnull) { - // if schema is set, order the fields in the right order on each row - $ordered_fields = keyset[]; - foreach ($schema['fields'] as $field) { - $ordered_fields[] = $field['name']; - } - - foreach ($res as $row) { - invariant($row is dict<_, _>, 'each item in evaluated result of SubqueryExpression must be row'); - - $m = dict[]; - foreach ($ordered_fields as $field) { - if (!C\contains_key($row, $field)) { - continue; - } - $m["{$name}.{$field}"] = $row[$field]; - } - $new_dataset[] = $m; - } - } else { - foreach ($res as $row) { - invariant($row is dict<_, _>, 'each item in evaluated result of SubqueryExpression must be row'); - - $m = dict[]; - foreach ($row as $key => $val) { - $m["{$name}.{$key}"] = $val; - } - $new_dataset[] = $m; - } - } - - if ($data || !$is_first_table) { - // do the join here. based on join type, pass in $data and $res to filter. and aliases - $data = JoinProcessor::process( - $conn, - $data, - $new_dataset, - $name, - $table['join_type'], - $table['join_operator'] ?? null, - $table['join_expression'] ?? 
null, - $schema, - ); - } else { - $data = $new_dataset; - } - - if ($is_first_table) { - Metrics::trackQuery(QueryType::SELECT, $conn->getServer()->name, $name, $sql); - $is_first_table = false; - } - } - - return $data; - } + public vec $tables = vec[]; + public bool $mostRecentHasAlias = false; + + public function addTable(from_table $table): void { + $this->tables[] = $table; + $this->mostRecentHasAlias = false; + } + + public function aliasRecentExpression(string $name): void { + $k = C\last_key($this->tables); + if ($k === null || $this->mostRecentHasAlias) { + throw new SQLFakeParseException('Unexpected AS'); + } + $this->tables[$k]['alias'] = $name; + $this->mostRecentHasAlias = true; + } + + /** + * The FROM clause of the query gets processed first, retrieving data from tables, executing subqueries, and handling joins + * This is also where we build up the $columns list which is commonly used throughout the entire library to map column references to indexes in this dataset + * @reviewer, we don't build up the $columns, since the variable is unused... + */ + public function process(AsyncMysqlConnection $conn, string $sql): dataset { + + $data = vec[]; + $is_first_table = true; + + foreach ($this->tables as $table) { + $schema = null; + if (Shapes::keyExists($table, 'subquery')) { + $res = $table['subquery']->evaluate(dict[], $conn); + $name = $table['name']; + } else { + $table_name = $table['name']; + + list($database, $table_name) = Query::parseTableName($conn, $table_name); + + // TODO if different database, should $name have that in it as well for other things like column references? probably, right? + $name = $table['alias'] ?? 
$table_name; + $schema = QueryContext::getSchema($database, $table_name); + if ($schema === null && QueryContext::$strictSchemaMode) { + throw new SQLFakeRuntimeException("Table $table_name not found in schema and strict mode is enabled"); + } + + $res = $conn->getServer()->getTable($database, $table_name); + + if ($res === null) { + $res = vec[]; + } + } + + invariant($res is KeyedContainer<_, _>, 'evaluated result of SubqueryExpression must be dataset'); + + $new_dataset = vec[]; + if ($schema is nonnull && QueryContext::$strictSchemaMode) { + foreach ($res as $row) { + $row as dict<_, _>; + $m = dict[]; + foreach ($row as $field => $val) { + $m["{$name}.{$field}"] = $val; + } + $new_dataset[] = $m; + } + } else if ($schema is nonnull) { + // if schema is set, order the fields in the right order on each row + $ordered_fields = keyset[]; + foreach ($schema['fields'] as $field) { + $ordered_fields[] = $field['name']; + } + + foreach ($res as $row) { + invariant($row is dict<_, _>, 'each item in evaluated result of SubqueryExpression must be row'); + + $m = dict[]; + foreach ($ordered_fields as $field) { + if (!C\contains_key($row, $field)) { + continue; + } + $m["{$name}.{$field}"] = $row[$field]; + } + $new_dataset[] = $m; + } + } else { + foreach ($res as $row) { + invariant($row is dict<_, _>, 'each item in evaluated result of SubqueryExpression must be row'); + + $m = dict[]; + foreach ($row as $key => $val) { + $m["{$name}.{$key}"] = $val; + } + $new_dataset[] = $m; + } + } + + if ($data || !$is_first_table) { + // do the join here. based on join type, pass in $data and $res to filter. and aliases + $data = JoinProcessor::process( + $conn, + $data, + $new_dataset, + $name, + $table['join_type'], + $table['join_operator'] ?? null, + $table['join_expression'] ?? 
null, + $schema, + ); + } else { + $data = $new_dataset; + } + + if ($is_first_table) { + Metrics::trackQuery(QueryType::SELECT, $conn->getServer()->name, $name, $sql); + $is_first_table = false; + } + } + + return $data; + } } diff --git a/src/Query/InsertQuery.php b/src/Query/InsertQuery.php index dddebc5..0b072ed 100644 --- a/src/Query/InsertQuery.php +++ b/src/Query/InsertQuery.php @@ -6,81 +6,81 @@ final class InsertQuery extends Query { - public function __construct(public string $table, public string $sql, public bool $ignoreDupes) {} + public function __construct(public string $table, public string $sql, public bool $ignoreDupes) {} - public vec $updateExpressions = vec[]; - public vec $insertColumns = vec[]; - public vec> $values = vec[]; + public vec $updateExpressions = vec[]; + public vec $insertColumns = vec[]; + public vec> $values = vec[]; - /** - * Insert rows, with validation - * Returns number of rows affected - */ - public function execute(AsyncMysqlConnection $conn): int { - list($database, $table_name) = Query::parseTableName($conn, $this->table); - $table = $conn->getServer()->getTable($database, $table_name) ?? vec[]; + /** + * Insert rows, with validation + * Returns number of rows affected + */ + public function execute(AsyncMysqlConnection $conn): int { + list($database, $table_name) = Query::parseTableName($conn, $this->table); + $table = $conn->getServer()->getTable($database, $table_name) ?? 
vec[]; - Metrics::trackQuery(QueryType::INSERT, $conn->getServer()->name, $table_name, $this->sql); + Metrics::trackQuery(QueryType::INSERT, $conn->getServer()->name, $table_name, $this->sql); - $schema = QueryContext::getSchema($database, $table_name); - if ($schema === null && QueryContext::$strictSchemaMode) { - throw new SQLFakeRuntimeException("Table $table_name not found in schema and strict mode is enabled"); - } + $schema = QueryContext::getSchema($database, $table_name); + if ($schema === null && QueryContext::$strictSchemaMode) { + throw new SQLFakeRuntimeException("Table $table_name not found in schema and strict mode is enabled"); + } - $rows_affected = 0; - foreach ($this->values as $value_list) { - $row = dict[]; - foreach ($this->insertColumns as $key => $col) { - $row[$col] = $value_list[$key]->evaluate(dict[], $conn); - } + $rows_affected = 0; + foreach ($this->values as $value_list) { + $row = dict[]; + foreach ($this->insertColumns as $key => $col) { + $row[$col] = $value_list[$key]->evaluate(dict[], $conn); + } - // can't enforce uniqueness or defaults if there is no schema available - if ($schema === null) { - $table[] = $row; - $rows_affected++; - continue; - } + // can't enforce uniqueness or defaults if there is no schema available + if ($schema === null) { + $table[] = $row; + $rows_affected++; + continue; + } - // ensure all fields are present with appropriate types and default values - // throw for nonexistent fields - $row = DataIntegrity::coerceToSchema($row, $schema); + // ensure all fields are present with appropriate types and default values + // throw for nonexistent fields + $row = DataIntegrity::coerceToSchema($row, $schema); - // check for unique key violations - $unique_key_violation = DataIntegrity::checkUniqueConstraints($table, $row, $schema); - if ($unique_key_violation is nonnull) { - list($msg, $row_id) = $unique_key_violation; - // is this an "INSERT ... ON DUPLICATE KEY UPDATE?" 
- // if so, this is where we apply the updates - if (!C\is_empty($this->updateExpressions)) { - $existing_row = $table[$row_id]; - list($affected, $table) = $this->applySet( - $conn, - $database, - $table_name, - dict[$row_id => $existing_row], - $table, - $this->updateExpressions, - $schema, - $row, - ); - // MySQL always counts dupe inserts twice intentionally - $rows_affected += $affected * 2; - continue; - } else if ($this->ignoreDupes) { - // silently continue if INSERT IGNORE was specified - continue; - } else if (!QueryContext::$relaxUniqueConstraints) { - throw new SQLFakeUniqueKeyViolation($msg); - } else { - continue; - } - } - $table[] = $row; - $rows_affected++; - } + // check for unique key violations + $unique_key_violation = DataIntegrity::checkUniqueConstraints($table, $row, $schema); + if ($unique_key_violation is nonnull) { + list($msg, $row_id) = $unique_key_violation; + // is this an "INSERT ... ON DUPLICATE KEY UPDATE?" + // if so, this is where we apply the updates + if (!C\is_empty($this->updateExpressions)) { + $existing_row = $table[$row_id]; + list($affected, $table) = $this->applySet( + $conn, + $database, + $table_name, + dict[$row_id => $existing_row], + $table, + $this->updateExpressions, + $schema, + $row, + ); + // MySQL always counts dupe inserts twice intentionally + $rows_affected += $affected * 2; + continue; + } else if ($this->ignoreDupes) { + // silently continue if INSERT IGNORE was specified + continue; + } else if (!QueryContext::$relaxUniqueConstraints) { + throw new SQLFakeUniqueKeyViolation($msg); + } else { + continue; + } + } + $table[] = $row; + $rows_affected++; + } - // write it back to the database - $conn->getServer()->saveTable($database, $table_name, $table); - return $rows_affected; - } + // write it back to the database + $conn->getServer()->saveTable($database, $table_name, $table); + return $rows_affected; + } } diff --git a/src/Query/JoinProcessor.php b/src/Query/JoinProcessor.php index 300948a..169fc4f 
100644 --- a/src/Query/JoinProcessor.php +++ b/src/Query/JoinProcessor.php @@ -9,302 +9,300 @@ */ abstract final class JoinProcessor { - // a sentinel to be used as a dict key for null values - const string NULL_SENTINEL = 'SLACK_SQLFAKE_NULL_SENTINEL'; - - public static function process( - AsyncMysqlConnection $conn, - dataset $left_dataset, - dataset $right_dataset, - string $right_table_name, - JoinType $join_type, - ?JoinOperator $_ref_type, - ?Expression $ref_clause, - ?table_schema $right_schema, - ): dataset { - - // MySQL supports JOIN (inner), LEFT OUTER JOIN, RIGHT OUTER JOIN, and implicitly CROSS JOIN (which uses commas), NATURAL - // conditions can be specified with ON or with USING () - // does not support FULL OUTER JOIN - - $out = vec[]; - - // filter can stay as a placeholder for NATURAL joins and CROSS joins which don't have explicit filter clauses - $filter = $ref_clause ?? new PlaceholderExpression(); - - // a special and extremely common case is joining on the comparison of two columns - // instead of evaluating the same expressions over and over again in nested loops, we can optimize this for a more efficient algorithm - // this is somewhat experimental and different merge strategies could be applied in more situations in the future - if ( - C\count($left_dataset) > 5 && - C\count($right_dataset) > 5 && - $filter is BinaryOperatorExpression && - $filter->left is ColumnExpression && - $filter->right is ColumnExpression && - $filter->operator === Operator::EQUALS && - ($join_type === JoinType::JOIN || $join_type === JoinType::STRAIGHT || $join_type === JoinType::LEFT) - ) { - return static::processHashJoin( - $conn, - $left_dataset, - $right_dataset, - $right_table_name, - $join_type, - $_ref_type, - $filter, - $right_schema, - ); - } - - switch ($join_type) { - case JoinType::JOIN: - case JoinType::STRAIGHT: - // straight join is just a query planner optimization of INNER JOIN, - // and it is actually what we are doing here anyway - foreach 
($left_dataset as $row) { - foreach ($right_dataset as $r) { - $candidate_row = Dict\merge($row, $r); - if ((bool)$filter->evaluate($candidate_row, $conn)) { - $out[] = $candidate_row; - } - } - } - break; - case JoinType::LEFT: - // for left outer joins, the null placeholder represents an appropriate number of nulled-out columns - // for the case where no rows in the right table match the left table, - // this null placeholder row is merged into the data set for that row - $null_placeholder = dict[]; - if ($right_schema !== null) { - foreach ($right_schema['fields'] as $field) { - $null_placeholder["{$right_table_name}.{$field['name']}"] = null; - } - } - - foreach ($left_dataset as $row) { - $any_match = false; - foreach ($right_dataset as $r) { - $candidate_row = Dict\merge($row, $r); - if ((bool)$filter->evaluate($candidate_row, $conn)) { - $out[] = $candidate_row; - $any_match = true; - } - } - - // for a left join, if no rows in the joined table matched filters - // we need to insert one row in with NULL for each of the target table columns - if (!$any_match) { - // if we have schema for the right table, use a null placeholder row with all the fields set to null - if ($right_schema !== null) { - $out[] = Dict\merge($row, $null_placeholder); - } else { - $out[] = $row; - } - } - } - break; - case JoinType::RIGHT: - // TODO: calculating the null placeholder set here is actually complex, - // we need to get a list of all columns from the schemas for all previous tables in the join sequence - - $null_placeholder = dict[]; - if ($right_schema !== null) { - foreach ($right_schema['fields'] as $field) { - $null_placeholder["{$right_table_name}.{$field['name']}"] = null; - } - } - - foreach ($right_dataset as $raw) { - $any_match = false; - foreach ($left_dataset as $row) { - $candidate_row = Dict\merge($row, $raw); - if ((bool)$filter->evaluate($candidate_row, $conn)) { - $out[] = $candidate_row; - $any_match = true; - } - } - - if (!$any_match) { - $out[] = $raw; - 
// TODO set null placeholder - } - } - break; - case JoinType::CROSS: - foreach ($left_dataset as $row) { - foreach ($right_dataset as $r) { - $out[] = Dict\merge($row, $r); - } - } - break; - case JoinType::NATURAL: - // unlike other join filters this one has to be built at runtime, using the list of columns that exists between the two tables - // for each column in the target table, see if there is a matching column in the rest of the data set. if so, make a filter that they must be equal. - $filter = self::buildNaturalJoinFilter($left_dataset, $right_dataset); - - // now basically just do a regular join - foreach ($left_dataset as $row) { - foreach ($right_dataset as $r) { - $candidate_row = Dict\merge($row, $r); - if ((bool)$filter->evaluate($candidate_row, $conn)) { - $out[] = $candidate_row; - } - } - } - break; - } - - return $out; - } - - /** - * Somewhat similar to USING clause, but we're just looking for all column names that match between the two tables - */ - protected static function buildNaturalJoinFilter(dataset $left_dataset, dataset $right_dataset): Expression { - $filter = null; - - $left = C\first($left_dataset); - $right = C\first($right_dataset); - if ($left === null || $right === null) { - throw new SQLFakeParseException('Attempted NATURAL join with no data present'); - } - foreach ($left as $column => $_val) { - $name = Str\split($column, '.') |> C\lastx($$); - foreach ($right as $col => $_v) { - $colname = Str\split($col, '.') |> C\lastx($$); - if ($colname === $name) { - $filter = self::addJoinFilterExpression($filter, $column, $col); - } - } - } - - // MySQL actually doesn't throw if there's no matching columns, but I think we can take the liberty to assume it's not what you meant to do and throw here - if ($filter === null) { - throw new SQLFakeParseException('NATURAL join keyword was used with tables that do not share any column names'); - } - - return $filter; - } - - /** - * For building a NATURAL join filter - */ - protected static 
function addJoinFilterExpression( - ?Expression $filter, - string $left_column, - string $right_column, - ): BinaryOperatorExpression { - - $left = new ColumnExpression( - shape('type' => TokenType::IDENTIFIER, 'value' => $left_column, 'raw' => $left_column), - ); - $right = new ColumnExpression( - shape('type' => TokenType::IDENTIFIER, 'value' => $right_column, 'raw' => $right_column), - ); - - // making a binary expression ensuring those two tokens are equal - $expr = new BinaryOperatorExpression($left, /* $negated */ false, Operator::EQUALS, $right); - - // if this is not the first condition, make an AND that wraps the current and new filter - if ($filter !== null) { - $filter = new BinaryOperatorExpression($filter, /* $negated */ false, Operator::AND, $expr); - } else { - $filter = $expr; - } - - return $filter; - } - - /** - * Coerce a column value to a string which can be used as a key - * for joining two datasets - * a sentinel is used for NULL, since that is not a valid arraykey - */ - private static function coerceToArrayKey(mixed $value): arraykey { - return $value is null ? 
self::NULL_SENTINEL : (string)$value; - } - - /** - * a specialized join algorithm that computes a hash containing the computed column results - * and row pointers for each row on one side - * this reduces repeated comparisons and is a performance improvement - */ - private static function processHashJoin( - AsyncMysqlConnection $conn, - dataset $left_dataset, - dataset $right_dataset, - string $right_table_name, - JoinType $join_type, - ?JoinOperator $_ref_type, - BinaryOperatorExpression $filter, - ?table_schema $right_schema, - ): dataset { - $left = $filter->left as ColumnExpression; - $right = $filter->right as ColumnExpression; - if ($left->tableName() === $right_table_name) { - // filter order may not match table order - // if the left filter is for the right table, swap the filters - list($left, $right) = vec[$right, $left]; - } - $out = vec[]; - - // evaluate the column expression once per row in the right dataset first, building up a temporary table that groups all rows together for each value - // multiple rows may have the same value. their ids in the original dataset are stored in a keyset - $right_temp_table = dict[]; - foreach ($right_dataset as $k => $r) { - $value = $right->evaluate($r, $conn); - $value = self::coerceToArrayKey($value); - $right_temp_table[$value] ??= keyset[]; - $right_temp_table[$value][] = $k; - } - - switch ($join_type) { - case JoinType::JOIN: - case JoinType::STRAIGHT: - foreach ($left_dataset as $row) { - $value = $left->evaluate($row, $conn) |> static::coerceToArrayKey($$); - // find all rows matching this value in the right temp table and get their full rows - foreach ($right_temp_table[$value] ?? 
keyset[] as $k) { - $out[] = Dict\merge($row, $right_dataset[$k]); - } - } - break; - case JoinType::LEFT: - // for left outer joins, the null placeholder represents an appropriate number of nulled-out columns - // for the case where no rows in the right table match the left table, - // this null placeholder row is merged into the data set for that row - $null_placeholder = dict[]; - if ($right_schema !== null) { - foreach ($right_schema['fields'] as $field) { - $null_placeholder["{$right_table_name}.{$field['name']}"] = null; - } - } - - foreach ($left_dataset as $row) { - $any_match = false; - $value = $left->evaluate($row, $conn) |> static::coerceToArrayKey($$); - foreach ($right_dataset as $r) { - foreach ($right_temp_table[$value] ?? keyset[] as $k) { - $out[] = Dict\merge($row, $right_dataset[$k]); - $any_match = true; - } - } - - // for a left join, if no rows in the joined table matched filters - // we need to insert one row in with NULL for each of the target table columns - if (!$any_match) { - // if we have schema for the right table, use a null placeholder row with all the fields set to null - if ($right_schema !== null) { - $out[] = Dict\merge($row, $null_placeholder); - } else { - $out[] = $row; - } - } - } - break; - default: - invariant_violation('unreachable'); - } - return $out; - } + // a sentinel to be used as a dict key for null values + const string NULL_SENTINEL = 'SLACK_SQLFAKE_NULL_SENTINEL'; + + public static function process( + AsyncMysqlConnection $conn, + dataset $left_dataset, + dataset $right_dataset, + string $right_table_name, + JoinType $join_type, + ?JoinOperator $_ref_type, + ?Expression $ref_clause, + ?table_schema $right_schema, + ): dataset { + + // MySQL supports JOIN (inner), LEFT OUTER JOIN, RIGHT OUTER JOIN, and implicitly CROSS JOIN (which uses commas), NATURAL + // conditions can be specified with ON or with USING () + // does not support FULL OUTER JOIN + + $out = vec[]; + + // filter can stay as a placeholder for 
NATURAL joins and CROSS joins which don't have explicit filter clauses + $filter = $ref_clause ?? new PlaceholderExpression(); + + // a special and extremely common case is joining on the comparison of two columns + // instead of evaluating the same expressions over and over again in nested loops, we can optimize this for a more efficient algorithm + // this is somewhat experimental and different merge strategies could be applied in more situations in the future + if ( + C\count($left_dataset) > 5 && + C\count($right_dataset) > 5 && + $filter is BinaryOperatorExpression && + $filter->left is ColumnExpression && + $filter->right is ColumnExpression && + $filter->operator === Operator::EQUALS && + ($join_type === JoinType::JOIN || $join_type === JoinType::STRAIGHT || $join_type === JoinType::LEFT) + ) { + return static::processHashJoin( + $conn, + $left_dataset, + $right_dataset, + $right_table_name, + $join_type, + $_ref_type, + $filter, + $right_schema, + ); + } + + switch ($join_type) { + case JoinType::JOIN: + case JoinType::STRAIGHT: + // straight join is just a query planner optimization of INNER JOIN, + // and it is actually what we are doing here anyway + foreach ($left_dataset as $row) { + foreach ($right_dataset as $r) { + $candidate_row = Dict\merge($row, $r); + if ((bool)$filter->evaluate($candidate_row, $conn)) { + $out[] = $candidate_row; + } + } + } + break; + case JoinType::LEFT: + // for left outer joins, the null placeholder represents an appropriate number of nulled-out columns + // for the case where no rows in the right table match the left table, + // this null placeholder row is merged into the data set for that row + $null_placeholder = dict[]; + if ($right_schema !== null) { + foreach ($right_schema['fields'] as $field) { + $null_placeholder["{$right_table_name}.{$field['name']}"] = null; + } + } + + foreach ($left_dataset as $row) { + $any_match = false; + foreach ($right_dataset as $r) { + $candidate_row = Dict\merge($row, $r); + if 
((bool)$filter->evaluate($candidate_row, $conn)) { + $out[] = $candidate_row; + $any_match = true; + } + } + + // for a left join, if no rows in the joined table matched filters + // we need to insert one row in with NULL for each of the target table columns + if (!$any_match) { + // if we have schema for the right table, use a null placeholder row with all the fields set to null + if ($right_schema !== null) { + $out[] = Dict\merge($row, $null_placeholder); + } else { + $out[] = $row; + } + } + } + break; + case JoinType::RIGHT: + // TODO: calculating the null placeholder set here is actually complex, + // we need to get a list of all columns from the schemas for all previous tables in the join sequence + + $null_placeholder = dict[]; + if ($right_schema !== null) { + foreach ($right_schema['fields'] as $field) { + $null_placeholder["{$right_table_name}.{$field['name']}"] = null; + } + } + + foreach ($right_dataset as $raw) { + $any_match = false; + foreach ($left_dataset as $row) { + $candidate_row = Dict\merge($row, $raw); + if ((bool)$filter->evaluate($candidate_row, $conn)) { + $out[] = $candidate_row; + $any_match = true; + } + } + + if (!$any_match) { + $out[] = $raw; + // TODO set null placeholder + } + } + break; + case JoinType::CROSS: + foreach ($left_dataset as $row) { + foreach ($right_dataset as $r) { + $out[] = Dict\merge($row, $r); + } + } + break; + case JoinType::NATURAL: + // unlike other join filters this one has to be built at runtime, using the list of columns that exists between the two tables + // for each column in the target table, see if there is a matching column in the rest of the data set. if so, make a filter that they must be equal. 
+ $filter = self::buildNaturalJoinFilter($left_dataset, $right_dataset); + + // now basically just do a regular join + foreach ($left_dataset as $row) { + foreach ($right_dataset as $r) { + $candidate_row = Dict\merge($row, $r); + if ((bool)$filter->evaluate($candidate_row, $conn)) { + $out[] = $candidate_row; + } + } + } + break; + } + + return $out; + } + + /** + * Somewhat similar to USING clause, but we're just looking for all column names that match between the two tables + */ + protected static function buildNaturalJoinFilter(dataset $left_dataset, dataset $right_dataset): Expression { + $filter = null; + + $left = C\first($left_dataset); + $right = C\first($right_dataset); + if ($left === null || $right === null) { + throw new SQLFakeParseException('Attempted NATURAL join with no data present'); + } + foreach ($left as $column => $_val) { + $name = Str\split($column, '.') |> C\lastx($$); + foreach ($right as $col => $_v) { + $colname = Str\split($col, '.') |> C\lastx($$); + if ($colname === $name) { + $filter = self::addJoinFilterExpression($filter, $column, $col); + } + } + } + + // MySQL actually doesn't throw if there's no matching columns, but I think we can take the liberty to assume it's not what you meant to do and throw here + if ($filter === null) { + throw new SQLFakeParseException('NATURAL join keyword was used with tables that do not share any column names'); + } + + return $filter; + } + + /** + * For building a NATURAL join filter + */ + protected static function addJoinFilterExpression( + ?Expression $filter, + string $left_column, + string $right_column, + ): BinaryOperatorExpression { + + $left = + new ColumnExpression(shape('type' => TokenType::IDENTIFIER, 'value' => $left_column, 'raw' => $left_column)); + $right = + new ColumnExpression(shape('type' => TokenType::IDENTIFIER, 'value' => $right_column, 'raw' => $right_column)); + + // making a binary expression ensuring those two tokens are equal + $expr = new 
BinaryOperatorExpression($left, /* $negated */ false, Operator::EQUALS, $right); + + // if this is not the first condition, make an AND that wraps the current and new filter + if ($filter !== null) { + $filter = new BinaryOperatorExpression($filter, /* $negated */ false, Operator::AND, $expr); + } else { + $filter = $expr; + } + + return $filter; + } + + /** + * Coerce a column value to a string which can be used as a key + * for joining two datasets + * a sentinel is used for NULL, since that is not a valid arraykey + */ + private static function coerceToArrayKey(mixed $value): arraykey { + return $value is null ? self::NULL_SENTINEL : (string)$value; + } + + /** + * a specialized join algorithm that computes a hash containing the computed column results + * and row pointers for each row on one side + * this reduces repeated comparisons and is a performance improvement + */ + private static function processHashJoin( + AsyncMysqlConnection $conn, + dataset $left_dataset, + dataset $right_dataset, + string $right_table_name, + JoinType $join_type, + ?JoinOperator $_ref_type, + BinaryOperatorExpression $filter, + ?table_schema $right_schema, + ): dataset { + $left = $filter->left as ColumnExpression; + $right = $filter->right as ColumnExpression; + if ($left->tableName() === $right_table_name) { + // filter order may not match table order + // if the left filter is for the right table, swap the filters + list($left, $right) = vec[$right, $left]; + } + $out = vec[]; + + // evaluate the column expression once per row in the right dataset first, building up a temporary table that groups all rows together for each value + // multiple rows may have the same value. 
their ids in the original dataset are stored in a keyset + $right_temp_table = dict[]; + foreach ($right_dataset as $k => $r) { + $value = $right->evaluate($r, $conn); + $value = self::coerceToArrayKey($value); + $right_temp_table[$value] ??= keyset[]; + $right_temp_table[$value][] = $k; + } + + switch ($join_type) { + case JoinType::JOIN: + case JoinType::STRAIGHT: + foreach ($left_dataset as $row) { + $value = $left->evaluate($row, $conn) |> static::coerceToArrayKey($$); + // find all rows matching this value in the right temp table and get their full rows + foreach ($right_temp_table[$value] ?? keyset[] as $k) { + $out[] = Dict\merge($row, $right_dataset[$k]); + } + } + break; + case JoinType::LEFT: + // for left outer joins, the null placeholder represents an appropriate number of nulled-out columns + // for the case where no rows in the right table match the left table, + // this null placeholder row is merged into the data set for that row + $null_placeholder = dict[]; + if ($right_schema !== null) { + foreach ($right_schema['fields'] as $field) { + $null_placeholder["{$right_table_name}.{$field['name']}"] = null; + } + } + + foreach ($left_dataset as $row) { + $any_match = false; + $value = $left->evaluate($row, $conn) |> static::coerceToArrayKey($$); + foreach ($right_dataset as $r) { + foreach ($right_temp_table[$value] ?? 
keyset[] as $k) { + $out[] = Dict\merge($row, $right_dataset[$k]); + $any_match = true; + } + } + + // for a left join, if no rows in the joined table matched filters + // we need to insert one row in with NULL for each of the target table columns + if (!$any_match) { + // if we have schema for the right table, use a null placeholder row with all the fields set to null + if ($right_schema !== null) { + $out[] = Dict\merge($row, $null_placeholder); + } else { + $out[] = $row; + } + } + } + break; + default: + invariant_violation('unreachable'); + } + return $out; + } } diff --git a/src/Query/Query.php b/src/Query/Query.php index b08661b..f4c6e38 100644 --- a/src/Query/Query.php +++ b/src/Query/Query.php @@ -12,220 +12,220 @@ */ abstract class Query { - public ?Expression $whereClause = null; - public ?order_by_clause $orderBy = null; - public ?limit_clause $limitClause = null; - - /** - * The initial query that was executed, no longer needed after parsing but retained for - * debugging and logging - */ - public string $sql; - public bool $ignoreDupes = false; - - protected function applyWhere(AsyncMysqlConnection $conn, dataset $data): dataset { - $where = $this->whereClause; - if ($where === null) { - // no where clause? cool! 
just return the given data - return $data; - } - - return Dict\filter($data, $row ==> (bool)$where->evaluate($row, $conn)); - } - - /** - * Apply the ORDER BY clause to sort the rows - */ - protected function applyOrderBy(AsyncMysqlConnection $_conn, dataset $data): dataset { - $order_by = $this->orderBy; - if ($order_by === null) { - return $data; - } - - // allow all column expressions to fall through to the full row - foreach ($order_by as $rule) { - $expr = $rule['expression']; - if ($expr is ColumnExpression && $expr->tableName() === null) { - $expr->allowFallthrough(); - } - } - - // sort function applies all ORDER BY criteria to compare two rows - $sort_fun = (row $a, row $b): int ==> { - foreach ($order_by as $rule) { - // in applySelect, the order by expressions are pre-evaluated and saved on the row with their names as keys, - // so we don't need to evaluate them again here - $value_a = $a[$rule['expression']->name]; - $value_b = $b[$rule['expression']->name]; - - if ($value_a != $value_b) { - if ($value_a is num && $value_b is num) { - return ( - ((float)$value_a < (float)$value_b ? 1 : 0) ^ (($rule['direction'] === SortDirection::DESC) ? 1 : 0) - ) - ? -1 - : 1; - } else { - return ( - // Use string comparison explicity to handle lexicographical ordering of things like '125' < '5' - (((Str\compare((string)$value_a, (string)$value_b)) < 0) ? 1 : 0) ^ - (($rule['direction'] === SortDirection::DESC) ? 1 : 0) - ) - ? -1 - : 1; - } - - } - } - return 0; - }; - - // Work around default sorting behavior to provide a usort that looks like MySQL, where equal values are ordered deterministically - // record the keys in a dict for usort - $data_temp = dict[]; - foreach ($data as $i => $item) { - $data_temp[$i] = tuple($i, $item); - } - - $data_temp = Dict\sort($data_temp, ((int, dict) $a, (int, dict) $b): int ==> { - $result = $sort_fun($a[1], $b[1]); - - return $result === 0 ? 
$b[0] - $a[0] : $result; - }); - - // re-key the input dataset - $data_temp = vec($data_temp); - // dicts maintain insert order. the keys will be inserted out of order but have to match the original - // keys for updates/deletes to be able to delete the right rows - $data = dict[]; - foreach ($data_temp as $item) { - $data[$item[0]] = $item[1]; - } - - return $data; - } - - protected function applyLimit(dataset $data): dataset { - $limit = $this->limitClause; - if ($limit === null) { - return $data; - } - - // keys in this dict are intentionally out of order if an ORDER BY clause occurred - // so first we get the ordered keys, then slice that list by the limit clause, then select only those keys - return Vec\keys($data) - |> Vec\slice($$, $limit['offset'], $limit['rowcount']) - |> Dict\select_keys($data, $$); - } - - /** - * Parses a table name that may contain a . to reference another database - * Returns the fully qualified database name and table name as a tuple - * If there is no ".", the database name will be the connection's current database - */ - public static function parseTableName(AsyncMysqlConnection $conn, string $table): (string, string) { - // referencing a table from another database on the same server? 
- if (Str\contains($table, '.')) { - $parts = Str\split($table, '.'); - if (C\count($parts) !== 2) { - throw new SQLFakeRuntimeException("Table name $table has too many parts"); - } - list($database, $table_name) = $parts; - return tuple($database, $table_name); - } else { - // otherwise use connection context's database - $database = $conn->getDatabase(); - return tuple($database, $table); - } - } - - /** - * Apply the "SET" clause of an UPDATE, or "ON DUPLICATE KEY UPDATE" - */ - protected function applySet( - AsyncMysqlConnection $conn, - string $database, - string $table_name, - dataset $filtered_rows, - dataset $original_table, - vec $set_clause, - ?table_schema $table_schema, - /* for dupe inserts only */ - ?row $values = null, - ): (int, vec>) { - - $original_table as vec<_>; - - $valid_fields = null; - if ($table_schema !== null) { - $valid_fields = Keyset\map($table_schema['fields'], $field ==> $field['name']); - } - - $set_clauses = vec[]; - foreach ($set_clause as $expression) { - // the parser already asserts this at parse time - $left = $expression->left as ColumnExpression; - $right = $expression->right as nonnull; - $column = $left->name; - - // If we know the valid fields for this table, only allow setting those - if ($valid_fields !== null) { - if (!C\contains($valid_fields, $column)) { - throw new SQLFakeRuntimeException("Invalid update column {$column}"); - } - } - - $set_clauses[] = shape('column' => $column, 'expression' => $right); - } - - $update_count = 0; - - foreach ($filtered_rows as $row_id => $row) { - $changes_found = false; - - // a copy of the $row to be updated - $update_row = $row; - if ($values is nonnull) { - // this is a bit of a hack to make the VALUES() function work without changing the - // interface of all ->evaluate() expressions to include the values list as well - // we put the values on the row as though they were another table - // we do this on a copy so that we don't accidentally save these to the table - foreach 
($values as $col => $val) { - $update_row['sql_fake_values.'.$col] = $val; - } - } - foreach ($set_clauses as $clause) { - $existing_value = $row[$clause['column']] ?? null; - $expr = $clause['expression']; - $new_value = $expr->evaluate($update_row, $conn); - - if ($new_value !== $existing_value) { - $row[$clause['column']] = $new_value; - $changes_found = true; - } - } - - if ($changes_found) { - if ($table_schema is nonnull) { - // throw on invalid data types if strict mode - $row = DataIntegrity::coerceToSchema($row, $table_schema); - $result = DataIntegrity::checkUniqueConstraints($original_table, $row, $table_schema, $row_id); - if ($result is nonnull) { - if ($this->ignoreDupes) { - continue; - } - if (!QueryContext::$relaxUniqueConstraints) { - throw new SQLFakeUniqueKeyViolation($result[0]); - } - } - } - $original_table[$row_id] = $row; - $update_count++; - } - } - - // write it back to the database - $conn->getServer()->saveTable($database, $table_name, $original_table); - return tuple($update_count, $original_table); - } + public ?Expression $whereClause = null; + public ?order_by_clause $orderBy = null; + public ?limit_clause $limitClause = null; + + /** + * The initial query that was executed, no longer needed after parsing but retained for + * debugging and logging + */ + public string $sql; + public bool $ignoreDupes = false; + + protected function applyWhere(AsyncMysqlConnection $conn, dataset $data): dataset { + $where = $this->whereClause; + if ($where === null) { + // no where clause? cool! 
just return the given data + return $data; + } + + return Dict\filter($data, $row ==> (bool)$where->evaluate($row, $conn)); + } + + /** + * Apply the ORDER BY clause to sort the rows + */ + protected function applyOrderBy(AsyncMysqlConnection $_conn, dataset $data): dataset { + $order_by = $this->orderBy; + if ($order_by === null) { + return $data; + } + + // allow all column expressions to fall through to the full row + foreach ($order_by as $rule) { + $expr = $rule['expression']; + if ($expr is ColumnExpression && $expr->tableName() === null) { + $expr->allowFallthrough(); + } + } + + // sort function applies all ORDER BY criteria to compare two rows + $sort_fun = (row $a, row $b): int ==> { + foreach ($order_by as $rule) { + // in applySelect, the order by expressions are pre-evaluated and saved on the row with their names as keys, + // so we don't need to evaluate them again here + $value_a = $a[$rule['expression']->name]; + $value_b = $b[$rule['expression']->name]; + + if ($value_a != $value_b) { + if ($value_a is num && $value_b is num) { + return ( + ((float)$value_a < (float)$value_b ? 1 : 0) ^ (($rule['direction'] === SortDirection::DESC) ? 1 : 0) + ) + ? -1 + : 1; + } else { + return ( + // Use string comparison explicity to handle lexicographical ordering of things like '125' < '5' + (((Str\compare((string)$value_a, (string)$value_b)) < 0) ? 1 : 0) ^ + (($rule['direction'] === SortDirection::DESC) ? 1 : 0) + ) + ? -1 + : 1; + } + + } + } + return 0; + }; + + // Work around default sorting behavior to provide a usort that looks like MySQL, where equal values are ordered deterministically + // record the keys in a dict for usort + $data_temp = dict[]; + foreach ($data as $i => $item) { + $data_temp[$i] = tuple($i, $item); + } + + $data_temp = Dict\sort($data_temp, ((int, dict) $a, (int, dict) $b): int ==> { + $result = $sort_fun($a[1], $b[1]); + + return $result === 0 ? 
$b[0] - $a[0] : $result; + }); + + // re-key the input dataset + $data_temp = vec($data_temp); + // dicts maintain insert order. the keys will be inserted out of order but have to match the original + // keys for updates/deletes to be able to delete the right rows + $data = dict[]; + foreach ($data_temp as $item) { + $data[$item[0]] = $item[1]; + } + + return $data; + } + + protected function applyLimit(dataset $data): dataset { + $limit = $this->limitClause; + if ($limit === null) { + return $data; + } + + // keys in this dict are intentionally out of order if an ORDER BY clause occurred + // so first we get the ordered keys, then slice that list by the limit clause, then select only those keys + return Vec\keys($data) + |> Vec\slice($$, $limit['offset'], $limit['rowcount']) + |> Dict\select_keys($data, $$); + } + + /** + * Parses a table name that may contain a . to reference another database + * Returns the fully qualified database name and table name as a tuple + * If there is no ".", the database name will be the connection's current database + */ + public static function parseTableName(AsyncMysqlConnection $conn, string $table): (string, string) { + // referencing a table from another database on the same server? 
+ if (Str\contains($table, '.')) { + $parts = Str\split($table, '.'); + if (C\count($parts) !== 2) { + throw new SQLFakeRuntimeException("Table name $table has too many parts"); + } + list($database, $table_name) = $parts; + return tuple($database, $table_name); + } else { + // otherwise use connection context's database + $database = $conn->getDatabase(); + return tuple($database, $table); + } + } + + /** + * Apply the "SET" clause of an UPDATE, or "ON DUPLICATE KEY UPDATE" + */ + protected function applySet( + AsyncMysqlConnection $conn, + string $database, + string $table_name, + dataset $filtered_rows, + dataset $original_table, + vec $set_clause, + ?table_schema $table_schema, + /* for dupe inserts only */ + ?row $values = null, + ): (int, vec>) { + + $original_table as vec<_>; + + $valid_fields = null; + if ($table_schema !== null) { + $valid_fields = Keyset\map($table_schema['fields'], $field ==> $field['name']); + } + + $set_clauses = vec[]; + foreach ($set_clause as $expression) { + // the parser already asserts this at parse time + $left = $expression->left as ColumnExpression; + $right = $expression->right as nonnull; + $column = $left->name; + + // If we know the valid fields for this table, only allow setting those + if ($valid_fields !== null) { + if (!C\contains($valid_fields, $column)) { + throw new SQLFakeRuntimeException("Invalid update column {$column}"); + } + } + + $set_clauses[] = shape('column' => $column, 'expression' => $right); + } + + $update_count = 0; + + foreach ($filtered_rows as $row_id => $row) { + $changes_found = false; + + // a copy of the $row to be updated + $update_row = $row; + if ($values is nonnull) { + // this is a bit of a hack to make the VALUES() function work without changing the + // interface of all ->evaluate() expressions to include the values list as well + // we put the values on the row as though they were another table + // we do this on a copy so that we don't accidentally save these to the table + foreach 
($values as $col => $val) { + $update_row['sql_fake_values.'.$col] = $val; + } + } + foreach ($set_clauses as $clause) { + $existing_value = $row[$clause['column']] ?? null; + $expr = $clause['expression']; + $new_value = $expr->evaluate($update_row, $conn); + + if ($new_value !== $existing_value) { + $row[$clause['column']] = $new_value; + $changes_found = true; + } + } + + if ($changes_found) { + if ($table_schema is nonnull) { + // throw on invalid data types if strict mode + $row = DataIntegrity::coerceToSchema($row, $table_schema); + $result = DataIntegrity::checkUniqueConstraints($original_table, $row, $table_schema, $row_id); + if ($result is nonnull) { + if ($this->ignoreDupes) { + continue; + } + if (!QueryContext::$relaxUniqueConstraints) { + throw new SQLFakeUniqueKeyViolation($result[0]); + } + } + } + $original_table[$row_id] = $row; + $update_count++; + } + } + + // write it back to the database + $conn->getServer()->saveTable($database, $table_name, $original_table); + return tuple($update_count, $original_table); + } } diff --git a/src/Query/SelectQuery.php b/src/Query/SelectQuery.php index 3abcffc..2c72071 100644 --- a/src/Query/SelectQuery.php +++ b/src/Query/SelectQuery.php @@ -6,305 +6,305 @@ final class SelectQuery extends Query { - public vec $selectExpressions = vec[]; - public ?FromClause $fromClause = null; - public ?vec $groupBy = null; - public ?Expression $havingClause = null; - public vec MultiOperand, 'query' => SelectQuery)> $multiQueries = vec[]; - - public keyset $options = keyset[]; - // this tracks whether we found a comma in between expressions - public bool $needsSeparator = false; - public bool $mostRecentHasAlias = false; - - public function __construct(public string $sql) {} - - public function addSelectExpression(Expression $expr): void { - if ($this->needsSeparator) { - throw new SQLFakeParseException('Unexpected expression!'); - } - $this->selectExpressions[] = $expr; - $this->needsSeparator = true; - 
$this->mostRecentHasAlias = false; - } - - public function addOption(string $option): void { - $this->options[] = $option; - } - - public function aliasRecentExpression(string $name): void { - $k = C\last_key($this->selectExpressions); - if ($k === null || $this->mostRecentHasAlias) { - throw new SQLFakeParseException('Unexpected AS'); - } - $this->selectExpressions[$k]->name = $name; - $this->mostRecentHasAlias = true; - } - - public function addMultiQuery(MultiOperand $type, SelectQuery $query): void { - $this->multiQueries[] = shape('type' => $type, 'query' => $query); - } - - /** - * Run the query - * The 2nd parameter is for supporting correlated subqueries, not currently supported - */ - public function execute(AsyncMysqlConnection $conn, ?row $_ = null): dataset { - - return - // FROM clause handling - builds a data set including extracting rows from tables, applying joins - $this->applyFrom($conn) - // WHERE caluse - filter out any rows that don't match it - |> $this->applyWhere($conn, $$) - // GROUP BY clause - may group the rows if necessary. all clauses after this need to know how to handled both grouped and ungrouped inputs - |> $this->applyGroupBy($conn, $$) - // HAVING clause, filter out any rows not matching it - |> $this->applyHaving($conn, $$) - // SELECT clause. this is where we actually run the expressions in the SELECT - |> $this->applySelect($conn, $$) - // ORDER BY. 
this runs after select because it could use expressions from the select - |> $this->applyOrderBy($conn, $$) - // LIMIT clause - |> $this->applyLimit($$) - // filter out any data that we needed for the ORDER BY that is not supposed to be returned - |> $this->removeOrderByExtras($conn, $$) - // this recurses in case there are any UNION, EXCEPT, INTERSECT keywords - |> $this->processMultiQuery($conn, $$); - } - - /** - * The FROM clause of the query gets processed first, retrieving data from tables, executing subqueries, and handling joins - * This is also where we build up the $columns list which is commonly used throughout the entire library to map column references to indexes in this dataset - */ - protected function applyFrom(AsyncMysqlConnection $conn): dataset { - - $from = $this->fromClause; - if ($from === null) { - // we put one empty row when there is no FROM so that queries like "SELECT 1" will return a row - return vec[dict[]]; - } - - return $from->process($conn, $this->sql); - } - - /** - * Apply the GROUP BY clause to group rows by a set of expressions. 
- * This may also group the rows if the select list contains an aggregate function, which requires an implicit grouping - */ - protected function applyGroupBy(AsyncMysqlConnection $conn, dataset $data): dataset { - $group_by = $this->groupBy; - $select_expressions = $this->selectExpressions; - if ($group_by !== null) { - $grouped_data = dict[]; - foreach ($data as $row) { - $hashes = ''; - foreach ($group_by as $expr) { - $hashes .= \sha1((string)$expr->evaluate($row, $conn)); - } - $hash = \sha1($hashes); - if (!C\contains_key($grouped_data, $hash)) { - $grouped_data[$hash] = dict[]; - } - $count = C\count($grouped_data[$hash]); - $grouped_data[$hash][(string)$count] = $row; - } - - $data = vec($grouped_data); - } else { - $found_aggregate = false; - foreach ($select_expressions as $expr) { - if ($expr is FunctionExpression && $expr->isAggregate()) { - $found_aggregate = true; - break; - } - } - - // if we have an aggregate function in the select clause but no group by, do an implicit group that puts all rows in one grouping - // this makes things like "SELECT COUNT(*) FROM mytable" work - if ($found_aggregate) { - return vec[Dict\map_keys($data, $k ==> (string)$k)]; - } - } - - // vec[dict[0 => dict['0' => $row, '1' => $row], 1 => dict['0' => $row]] - - return $data; - } - - /** - * Apply the HAVING clause to every (maybe grouped) row in the data set. Only return truthy results. - */ - protected function applyHaving(AsyncMysqlConnection $conn, dataset $data): dataset { - $havingClause = $this->havingClause; - if ($havingClause is nonnull) { - return Vec\filter($data, $row ==> (bool)$havingClause->evaluate($row, $conn)); - } - - return $data; - } - - /** - * Generate the result set containing SELECT expressions - */ - protected function applySelect(AsyncMysqlConnection $conn, dataset $data): dataset { - - // The ORDER BY portion of queries run after this SELECT code. 
- // However, it is possible to ORDER BY a field that we do not intend to SELECT by, - // or put another way: that we do not intend to return from the query. - // - // But we have to include those here anyway so we can perform the ORDER BY - // and then throw them away. - - $order_by_expressions = $this->orderBy ?? vec[]; - - $out = vec[]; - - // ok now you got that filter, let's do the formatting - foreach ($data as $row) { - $formatted_row = dict[]; - - foreach ($this->selectExpressions as $expr) { - if ($expr is ColumnExpression && $expr->name === '*') { - // if it's a GROUP BY, take the first row from each grouping. - // SELECT * with a GROUP BY effectively picks the first row from each group - $first_value = C\first($row); - if ($first_value is dict<_, _>) { - $row = $first_value; - } - foreach ($row as $col => $val) { - $parts = Str\split((string)$col, '.'); - if ($expr->tableName() is nonnull) { - list($col_table_name, $col_name) = $parts; - if ($col_table_name == $expr->tableName()) { - if (!C\contains_key($formatted_row, $col)) { - $formatted_row[$col_name] = $val; - } - } - } else { - $col = C\last($parts); - if ($col is nonnull) { - $formatted_row[$col] ??= $val; - } - } - - } - continue; - } - - list($name, $val) = $expr->evaluateWithName($row, $conn); - - // subquery: unroll the expression to get the value out - if ($expr is SubqueryExpression) { - invariant($val is KeyedContainer<_, _>, 'subquery results must be KeyedContainer'); - if (C\count($val) > 1) { - throw new SQLFakeRuntimeException('Subquery returned more than one row'); - } - if (C\count($val) === 0) { - $val = null; - } else { - foreach ($val as $r) { - $r as KeyedContainer<_, _>; - if (C\count($r) !== 1) { - throw new SQLFakeRuntimeException('Subquery result should contain 1 column'); - } - $val = C\onlyx($r); - } - } - } - $formatted_row[$name] = $val; - } - - // Adding any fields needed by the ORDER BY not already returned by the SELECT - foreach ($order_by_expressions as $order_by) { - 
$row as dict<_, _>; - list($name, $val) = $order_by['expression']->evaluateWithName(/* HH_FIXME[4110] generics */ $row, $conn); - $formatted_row[$name] ??= $val; - } - - $out[] = $formatted_row; - } - - if (C\contains_key($this->options, 'DISTINCT')) { - return Vec\unique_by($out, (row $row): string ==> Str\join(Vec\map($row, $col ==> (string)$col), '-')); - } - - return $out; - } - - /** - * Remove fields that we do not SELECT by, but we do ORDER BY - */ - protected function removeOrderByExtras(AsyncMysqlConnection $_conn, dataset $data): dataset { - - $order_by = $this->orderBy; - if ($order_by === null || C\count($data) === 0) { - return $data; - } - - $order_by_names = keyset[]; - $select_field_names = keyset[]; - - foreach ($this->selectExpressions as $expr) { - $name = $expr->name; - // if we are selecting everything we know the field is included - if ($name === '*') { - return $data; - } - if ($name !== null) { - $select_field_names[] = $name; - } - } - - foreach ($order_by as $o) { - $name = $o['expression']->name; - if ($name !== null) { - $order_by_names[] = $name; - } - } - - $remove_fields = Keyset\diff($order_by_names, $select_field_names); - if (C\is_empty($remove_fields)) { - return $data; - } - - // remove the fields we don't want from each row - return Vec\map($data, $row ==> Dict\filter_keys($row, $field ==> !C\contains_key($remove_fields, $field))); - } - - /** - * Process a query that contains multiple queries such as with UNION, INTERSECT, EXCEPT, UNION ALL - */ - protected function processMultiQuery(AsyncMysqlConnection $conn, dataset $data): dataset { - - // function used to stringify rows for comparison - $row_encoder = (row $row): string ==> Str\join(Vec\map($row, $col ==> (string)$col), '-'); - - foreach ($this->multiQueries as $sub) { - // invoke the subquery - $subquery_results = $sub['query']->execute($conn); - - // now put the results together based on the keyword - switch ($sub['type']) { - case MultiOperand::UNION: - // contact the 
results, then get unique rows by converting all fields to string and comparing a joined-up representation - $data = Vec\concat($data, $subquery_results) |> Vec\unique_by($$, $row_encoder); - break; - case MultiOperand::UNION_ALL: - // just concatenate with no uniqueness - $data = Vec\concat($data, $subquery_results); - break; - case MultiOperand::INTERSECT: - // there's no Vec\intersect_by currently - $encoded_data = Keyset\map($data, $row_encoder); - $data = Vec\filter($subquery_results, $row ==> C\contains_key($encoded_data, $row_encoder($row))); - break; - case MultiOperand::EXCEPT: - $data = Vec\diff_by($data, $subquery_results, $row_encoder); - break; - } - } - - return $data; - } + public vec $selectExpressions = vec[]; + public ?FromClause $fromClause = null; + public ?vec $groupBy = null; + public ?Expression $havingClause = null; + public vec MultiOperand, 'query' => SelectQuery)> $multiQueries = vec[]; + + public keyset $options = keyset[]; + // this tracks whether we found a comma in between expressions + public bool $needsSeparator = false; + public bool $mostRecentHasAlias = false; + + public function __construct(public string $sql) {} + + public function addSelectExpression(Expression $expr): void { + if ($this->needsSeparator) { + throw new SQLFakeParseException('Unexpected expression!'); + } + $this->selectExpressions[] = $expr; + $this->needsSeparator = true; + $this->mostRecentHasAlias = false; + } + + public function addOption(string $option): void { + $this->options[] = $option; + } + + public function aliasRecentExpression(string $name): void { + $k = C\last_key($this->selectExpressions); + if ($k === null || $this->mostRecentHasAlias) { + throw new SQLFakeParseException('Unexpected AS'); + } + $this->selectExpressions[$k]->name = $name; + $this->mostRecentHasAlias = true; + } + + public function addMultiQuery(MultiOperand $type, SelectQuery $query): void { + $this->multiQueries[] = shape('type' => $type, 'query' => $query); + } + + /** + * Run 
the query + * The 2nd parameter is for supporting correlated subqueries, not currently supported + */ + public function execute(AsyncMysqlConnection $conn, ?row $_ = null): dataset { + + return + // FROM clause handling - builds a data set including extracting rows from tables, applying joins + $this->applyFrom($conn) + // WHERE caluse - filter out any rows that don't match it + |> $this->applyWhere($conn, $$) + // GROUP BY clause - may group the rows if necessary. all clauses after this need to know how to handled both grouped and ungrouped inputs + |> $this->applyGroupBy($conn, $$) + // HAVING clause, filter out any rows not matching it + |> $this->applyHaving($conn, $$) + // SELECT clause. this is where we actually run the expressions in the SELECT + |> $this->applySelect($conn, $$) + // ORDER BY. this runs after select because it could use expressions from the select + |> $this->applyOrderBy($conn, $$) + // LIMIT clause + |> $this->applyLimit($$) + // filter out any data that we needed for the ORDER BY that is not supposed to be returned + |> $this->removeOrderByExtras($conn, $$) + // this recurses in case there are any UNION, EXCEPT, INTERSECT keywords + |> $this->processMultiQuery($conn, $$); + } + + /** + * The FROM clause of the query gets processed first, retrieving data from tables, executing subqueries, and handling joins + * This is also where we build up the $columns list which is commonly used throughout the entire library to map column references to indexes in this dataset + */ + protected function applyFrom(AsyncMysqlConnection $conn): dataset { + + $from = $this->fromClause; + if ($from === null) { + // we put one empty row when there is no FROM so that queries like "SELECT 1" will return a row + return vec[dict[]]; + } + + return $from->process($conn, $this->sql); + } + + /** + * Apply the GROUP BY clause to group rows by a set of expressions. 
+ * This may also group the rows if the select list contains an aggregate function, which requires an implicit grouping + */ + protected function applyGroupBy(AsyncMysqlConnection $conn, dataset $data): dataset { + $group_by = $this->groupBy; + $select_expressions = $this->selectExpressions; + if ($group_by !== null) { + $grouped_data = dict[]; + foreach ($data as $row) { + $hashes = ''; + foreach ($group_by as $expr) { + $hashes .= \sha1((string)$expr->evaluate($row, $conn)); + } + $hash = \sha1($hashes); + if (!C\contains_key($grouped_data, $hash)) { + $grouped_data[$hash] = dict[]; + } + $count = C\count($grouped_data[$hash]); + $grouped_data[$hash][(string)$count] = $row; + } + + $data = vec($grouped_data); + } else { + $found_aggregate = false; + foreach ($select_expressions as $expr) { + if ($expr is FunctionExpression && $expr->isAggregate()) { + $found_aggregate = true; + break; + } + } + + // if we have an aggregate function in the select clause but no group by, do an implicit group that puts all rows in one grouping + // this makes things like "SELECT COUNT(*) FROM mytable" work + if ($found_aggregate) { + return vec[Dict\map_keys($data, $k ==> (string)$k)]; + } + } + + // vec[dict[0 => dict['0' => $row, '1' => $row], 1 => dict['0' => $row]] + + return $data; + } + + /** + * Apply the HAVING clause to every (maybe grouped) row in the data set. Only return truthy results. + */ + protected function applyHaving(AsyncMysqlConnection $conn, dataset $data): dataset { + $havingClause = $this->havingClause; + if ($havingClause is nonnull) { + return Vec\filter($data, $row ==> (bool)$havingClause->evaluate($row, $conn)); + } + + return $data; + } + + /** + * Generate the result set containing SELECT expressions + */ + protected function applySelect(AsyncMysqlConnection $conn, dataset $data): dataset { + + // The ORDER BY portion of queries run after this SELECT code. 
+ // However, it is possible to ORDER BY a field that we do not intend to SELECT by, + // or put another way: that we do not intend to return from the query. + // + // But we have to include those here anyway so we can perform the ORDER BY + // and then throw them away. + + $order_by_expressions = $this->orderBy ?? vec[]; + + $out = vec[]; + + // ok now you got that filter, let's do the formatting + foreach ($data as $row) { + $formatted_row = dict[]; + + foreach ($this->selectExpressions as $expr) { + if ($expr is ColumnExpression && $expr->name === '*') { + // if it's a GROUP BY, take the first row from each grouping. + // SELECT * with a GROUP BY effectively picks the first row from each group + $first_value = C\first($row); + if ($first_value is dict<_, _>) { + $row = $first_value; + } + foreach ($row as $col => $val) { + $parts = Str\split((string)$col, '.'); + if ($expr->tableName() is nonnull) { + list($col_table_name, $col_name) = $parts; + if ($col_table_name == $expr->tableName()) { + if (!C\contains_key($formatted_row, $col)) { + $formatted_row[$col_name] = $val; + } + } + } else { + $col = C\last($parts); + if ($col is nonnull) { + $formatted_row[$col] ??= $val; + } + } + + } + continue; + } + + list($name, $val) = $expr->evaluateWithName($row, $conn); + + // subquery: unroll the expression to get the value out + if ($expr is SubqueryExpression) { + invariant($val is KeyedContainer<_, _>, 'subquery results must be KeyedContainer'); + if (C\count($val) > 1) { + throw new SQLFakeRuntimeException('Subquery returned more than one row'); + } + if (C\count($val) === 0) { + $val = null; + } else { + foreach ($val as $r) { + $r as KeyedContainer<_, _>; + if (C\count($r) !== 1) { + throw new SQLFakeRuntimeException('Subquery result should contain 1 column'); + } + $val = C\onlyx($r); + } + } + } + $formatted_row[$name] = $val; + } + + // Adding any fields needed by the ORDER BY not already returned by the SELECT + foreach ($order_by_expressions as $order_by) { + 
$row as dict<_, _>; + list($name, $val) = $order_by['expression']->evaluateWithName(/* HH_FIXME[4110] generics */ $row, $conn); + $formatted_row[$name] ??= $val; + } + + $out[] = $formatted_row; + } + + if (C\contains_key($this->options, 'DISTINCT')) { + return Vec\unique_by($out, (row $row): string ==> Str\join(Vec\map($row, $col ==> (string)$col), '-')); + } + + return $out; + } + + /** + * Remove fields that we do not SELECT by, but we do ORDER BY + */ + protected function removeOrderByExtras(AsyncMysqlConnection $_conn, dataset $data): dataset { + + $order_by = $this->orderBy; + if ($order_by === null || C\count($data) === 0) { + return $data; + } + + $order_by_names = keyset[]; + $select_field_names = keyset[]; + + foreach ($this->selectExpressions as $expr) { + $name = $expr->name; + // if we are selecting everything we know the field is included + if ($name === '*') { + return $data; + } + if ($name !== null) { + $select_field_names[] = $name; + } + } + + foreach ($order_by as $o) { + $name = $o['expression']->name; + if ($name !== null) { + $order_by_names[] = $name; + } + } + + $remove_fields = Keyset\diff($order_by_names, $select_field_names); + if (C\is_empty($remove_fields)) { + return $data; + } + + // remove the fields we don't want from each row + return Vec\map($data, $row ==> Dict\filter_keys($row, $field ==> !C\contains_key($remove_fields, $field))); + } + + /** + * Process a query that contains multiple queries such as with UNION, INTERSECT, EXCEPT, UNION ALL + */ + protected function processMultiQuery(AsyncMysqlConnection $conn, dataset $data): dataset { + + // function used to stringify rows for comparison + $row_encoder = (row $row): string ==> Str\join(Vec\map($row, $col ==> (string)$col), '-'); + + foreach ($this->multiQueries as $sub) { + // invoke the subquery + $subquery_results = $sub['query']->execute($conn); + + // now put the results together based on the keyword + switch ($sub['type']) { + case MultiOperand::UNION: + // contact the 
results, then get unique rows by converting all fields to string and comparing a joined-up representation + $data = Vec\concat($data, $subquery_results) |> Vec\unique_by($$, $row_encoder); + break; + case MultiOperand::UNION_ALL: + // just concatenate with no uniqueness + $data = Vec\concat($data, $subquery_results); + break; + case MultiOperand::INTERSECT: + // there's no Vec\intersect_by currently + $encoded_data = Keyset\map($data, $row_encoder); + $data = Vec\filter($subquery_results, $row ==> C\contains_key($encoded_data, $row_encoder($row))); + break; + case MultiOperand::EXCEPT: + $data = Vec\diff_by($data, $subquery_results, $row_encoder); + break; + } + } + + return $data; + } } diff --git a/src/Query/UpdateQuery.php b/src/Query/UpdateQuery.php index 00ea07d..cca5b57 100644 --- a/src/Query/UpdateQuery.php +++ b/src/Query/UpdateQuery.php @@ -4,30 +4,30 @@ final class UpdateQuery extends Query { - public function __construct(public from_table $updateClause, public string $sql, public bool $ignoreDupes) {} - - public vec $setClause = vec[]; - - public function execute(AsyncMysqlConnection $conn): int { - list($tableName, $database, $data) = $this->processUpdateClause($conn); - Metrics::trackQuery(QueryType::UPDATE, $conn->getServer()->name, $tableName, $this->sql); - $schema = QueryContext::getSchema($database, $tableName); - - list($rows_affected, $_) = $this->applyWhere($conn, $data) - |> $this->applyOrderBy($conn, $$) - |> $this->applyLimit($$) - |> $this->applySet($conn, $database, $tableName, $$, $data, $this->setClause, $schema); - - return $rows_affected; - } - - /** - * process the UPDATE clause to retrieve the table - * add a row identifier to each element in the result which we can later use to update the underlying table - */ - protected function processUpdateClause(AsyncMysqlConnection $conn): (string, string, dataset) { - list($database, $tableName) = Query::parseTableName($conn, $this->updateClause['name']); - $table = 
$conn->getServer()->getTable($database, $tableName) ?? vec[]; - return tuple($tableName, $database, $table); - } + public function __construct(public from_table $updateClause, public string $sql, public bool $ignoreDupes) {} + + public vec $setClause = vec[]; + + public function execute(AsyncMysqlConnection $conn): int { + list($tableName, $database, $data) = $this->processUpdateClause($conn); + Metrics::trackQuery(QueryType::UPDATE, $conn->getServer()->name, $tableName, $this->sql); + $schema = QueryContext::getSchema($database, $tableName); + + list($rows_affected, $_) = $this->applyWhere($conn, $data) + |> $this->applyOrderBy($conn, $$) + |> $this->applyLimit($$) + |> $this->applySet($conn, $database, $tableName, $$, $data, $this->setClause, $schema); + + return $rows_affected; + } + + /** + * process the UPDATE clause to retrieve the table + * add a row identifier to each element in the result which we can later use to update the underlying table + */ + protected function processUpdateClause(AsyncMysqlConnection $conn): (string, string, dataset) { + list($database, $tableName) = Query::parseTableName($conn, $this->updateClause['name']); + $table = $conn->getServer()->getTable($database, $tableName) ?? vec[]; + return tuple($tableName, $database, $table); + } } diff --git a/src/QueryContext.php b/src/QueryContext.php index e0ef79a..da1ec22 100644 --- a/src/QueryContext.php +++ b/src/QueryContext.php @@ -4,50 +4,50 @@ abstract final class QueryContext { - /** - * In strict mode, any query referencing a table not in the shcema - * will throw an exception - * - * This should be turned on if schema is available - */ - public static bool $strictSchemaMode = false; - - /** - * Emulate MySQL strict SQL mode. 
Invalid values for columns will - * throw instead of silently coercing the data - */ - public static bool $strictSQLMode = false; - - /** - * Set to true to allow unique key violations to be ignored temporarily - * May be useful when importing test data - */ - public static bool $relaxUniqueConstraints = false; - - /** - * 1: quiet, print nothing - * 2: verbose, print every query as it executes - * 3: very verbose, print query results as well - */ - public static Verbosity $verbosity = Verbosity::QUIET; - - /** - * Set to true to skip validating if this query will - * work on vitess - */ - public static bool $skipVitessValidation = false; - - /** - * Representation of database schema - * String keys are database names, with table names inside that list - * - * There is a built-in assumption that you don't have two databases on different - * servers with the same name but different schemas. We don't include server hostnames here - * because it's common to have sharded databases with the same names on different hosts - */ - public static dict> $schema = dict[]; - - public static function getSchema(string $database, string $table): ?table_schema { - return self::$schema[$database][$table] ?? null; - } + /** + * In strict mode, any query referencing a table not in the shcema + * will throw an exception + * + * This should be turned on if schema is available + */ + public static bool $strictSchemaMode = false; + + /** + * Emulate MySQL strict SQL mode. 
Invalid values for columns will + * throw instead of silently coercing the data + */ + public static bool $strictSQLMode = false; + + /** + * Set to true to allow unique key violations to be ignored temporarily + * May be useful when importing test data + */ + public static bool $relaxUniqueConstraints = false; + + /** + * 1: quiet, print nothing + * 2: verbose, print every query as it executes + * 3: very verbose, print query results as well + */ + public static Verbosity $verbosity = Verbosity::QUIET; + + /** + * Set to true to skip validating if this query will + * work on vitess + */ + public static bool $skipVitessValidation = false; + + /** + * Representation of database schema + * String keys are database names, with table names inside that list + * + * There is a built-in assumption that you don't have two databases on different + * servers with the same name but different schemas. We don't include server hostnames here + * because it's common to have sharded databases with the same names on different hosts + */ + public static dict> $schema = dict[]; + + public static function getSchema(string $database, string $table): ?table_schema { + return self::$schema[$database][$table] ?? 
null; + } } diff --git a/src/SQLCommandProcessor.php b/src/SQLCommandProcessor.php index 58cf8f6..e2bfdce 100644 --- a/src/SQLCommandProcessor.php +++ b/src/SQLCommandProcessor.php @@ -10,38 +10,38 @@ */ abstract final class SQLCommandProcessor { - public static function execute(string $sql, AsyncMysqlConnection $conn): (dataset, int) { - - // Check for unsupported statements - if (Str\starts_with_ci($sql, 'SET') || Str\starts_with_ci($sql, 'BEGIN') || Str\starts_with_ci($sql, 'COMMIT')) { - // we don't do any handling for these kinds of statements currently - return tuple(vec[], 0); - } - - if (Str\starts_with_ci($sql, 'ROLLBACK')) { - // unlike BEGIN and COMMIT, this actually needs to have material effect on the observed behavior - // even in a single test case, and so we need to throw since it's not implemented yet - // there's no reason we couldn't start supporting transactions in the future, just haven't done the work yet - throw new SQLFakeNotImplementedException('Transactions are not yet supported'); - } - - $query = SQLParser::parse($sql); - - $is_vitess_query = $conn->getServer()->config['is_vitess'] ?? 
false; - if ($is_vitess_query && !QueryContext::$skipVitessValidation) { - VitessQueryValidator::validate($query, $conn); - } - - if ($query is SelectQuery) { - return tuple($query->execute($conn), 0); - } else if ($query is UpdateQuery) { - return tuple(vec[], $query->execute($conn)); - } else if ($query is DeleteQuery) { - return tuple(vec[], $query->execute($conn)); - } else if ($query is InsertQuery) { - return tuple(vec[], $query->execute($conn)); - } else { - throw new SQLFakeNotImplementedException('Unhandled query type: '.\get_class($query)); - } - } + public static function execute(string $sql, AsyncMysqlConnection $conn): (dataset, int) { + + // Check for unsupported statements + if (Str\starts_with_ci($sql, 'SET') || Str\starts_with_ci($sql, 'BEGIN') || Str\starts_with_ci($sql, 'COMMIT')) { + // we don't do any handling for these kinds of statements currently + return tuple(vec[], 0); + } + + if (Str\starts_with_ci($sql, 'ROLLBACK')) { + // unlike BEGIN and COMMIT, this actually needs to have material effect on the observed behavior + // even in a single test case, and so we need to throw since it's not implemented yet + // there's no reason we couldn't start supporting transactions in the future, just haven't done the work yet + throw new SQLFakeNotImplementedException('Transactions are not yet supported'); + } + + $query = SQLParser::parse($sql); + + $is_vitess_query = $conn->getServer()->config['is_vitess'] ?? 
false; + if ($is_vitess_query && !QueryContext::$skipVitessValidation) { + VitessQueryValidator::validate($query, $conn); + } + + if ($query is SelectQuery) { + return tuple($query->execute($conn), 0); + } else if ($query is UpdateQuery) { + return tuple(vec[], $query->execute($conn)); + } else if ($query is DeleteQuery) { + return tuple(vec[], $query->execute($conn)); + } else if ($query is InsertQuery) { + return tuple(vec[], $query->execute($conn)); + } else { + throw new SQLFakeNotImplementedException('Unhandled query type: '.\get_class($query)); + } + } } diff --git a/src/SchemaGenerator.php b/src/SchemaGenerator.php index 885fdce..647e903 100644 --- a/src/SchemaGenerator.php +++ b/src/SchemaGenerator.php @@ -34,7 +34,7 @@ public function generateFromString(string $sql): dict { if ($default is nonnull && $default !== 'NULL') { $f['default'] = $default; } - + $unsigned = ($field['unsigned'] ?? null); if ($unsigned is nonnull) { $f['unsigned'] = $unsigned; diff --git a/src/Server.php b/src/Server.php index 023baf8..4a12cc3 100644 --- a/src/Server.php +++ b/src/Server.php @@ -6,124 +6,124 @@ final class Server { - public function __construct(public string $name, public ?server_config $config = null) {} - - private static dict $instances = dict[]; - private static keyset $snapshot_names = keyset[]; - - public static function getAll(): dict { - return static::$instances; - } - - public static function getAllTables(): dict>>>> { - return Dict\map(static::getAll(), ($server) ==> { - return $server->databases; - }); - } - - /** - * This will override everything in $instances - */ - public static function setAll(dict $instances): void { - static::$instances = $instances; - } - - public static function get(string $name): ?this { - return static::$instances[$name] ?? null; - } - - public function setConfig(server_config $config): void { - $this->config = $config; - } - - public static function getOrCreate(string $name): this { - $server = static::$instances[$name] ?? 
null; - if ($server === null) { - $server = new static($name); - static::$instances[$name] = $server; - } - return $server; - } - - public static function cloneByName(string $name, string $clone_name): this { - $clone = static::get($clone_name); - if ($clone === null) { - throw new SQLFakeRuntimeException("Server $clone_name not found, unable to clone databases and snapshots"); - } - - $server = static::getOrCreate($name); - $server->databases = $clone->databases; - $server->snapshots = $clone->snapshots; - return $server; - } - - public static function reset(): void { - foreach (static::getAll() as $server) { - $server->doReset(); - } - } - - public static function snapshot(string $name): void { - foreach (static::getAll() as $server) { - $server->doSnapshot($name); - } - static::$snapshot_names[] = $name; - } - - public static function restore(string $name): void { - if (!C\contains_key(static::$snapshot_names, $name)) { - throw new SQLFakeRuntimeException("Snapshot $name not found, unable to restore"); - } - foreach (static::getAll() as $server) { - $server->doRestore($name); - } - } - - protected function doSnapshot(string $name): void { - $this->snapshots[$name] = $this->databases; - } - - protected function doRestore(string $name): void { - $this->databases = $this->snapshots[$name] ?? 
dict[]; - } - - protected function doReset(): void { - $this->databases = dict[]; - } - - /** - * The main storage mechanism - * dict of strings (database schema names) - * -> dict of string table names to tables - * -> vec of rows - * -> dict of string column names to columns - * - * While a structure based on objects all the way down the stack may be more powerful and readable, - * This structure uses value types intentionally, to enable a relatively efficient reset/snapshot logic - * which is often used frequently between test cases - */ - public dict>>> $databases = dict[]; - private dict>>>> $snapshots = dict[]; - - /** - * Retrieve a table from the specified database, if it exists, by value - */ - public function getTable(string $dbname, string $name): ?vec> { - return $this->databases[$dbname][$name] ?? null; - } - - /** - * Save a table's rows back to the database - * note, because insert and update operations already grab the full table for checking constraints, - * we don't bother providing an insert or update helper here. - */ - public function saveTable(string $dbname, string $name, vec> $rows): void { - // create table if not exists - if (!C\contains_key($this->databases, $dbname)) { - $this->databases[$dbname] = dict[]; - } - - // save rows - $this->databases[$dbname][$name] = $rows; - } + public function __construct(public string $name, public ?server_config $config = null) {} + + private static dict $instances = dict[]; + private static keyset $snapshot_names = keyset[]; + + public static function getAll(): dict { + return static::$instances; + } + + public static function getAllTables(): dict>>>> { + return Dict\map(static::getAll(), ($server) ==> { + return $server->databases; + }); + } + + /** + * This will override everything in $instances + */ + public static function setAll(dict $instances): void { + static::$instances = $instances; + } + + public static function get(string $name): ?this { + return static::$instances[$name] ?? 
null; + } + + public function setConfig(server_config $config): void { + $this->config = $config; + } + + public static function getOrCreate(string $name): this { + $server = static::$instances[$name] ?? null; + if ($server === null) { + $server = new static($name); + static::$instances[$name] = $server; + } + return $server; + } + + public static function cloneByName(string $name, string $clone_name): this { + $clone = static::get($clone_name); + if ($clone === null) { + throw new SQLFakeRuntimeException("Server $clone_name not found, unable to clone databases and snapshots"); + } + + $server = static::getOrCreate($name); + $server->databases = $clone->databases; + $server->snapshots = $clone->snapshots; + return $server; + } + + public static function reset(): void { + foreach (static::getAll() as $server) { + $server->doReset(); + } + } + + public static function snapshot(string $name): void { + foreach (static::getAll() as $server) { + $server->doSnapshot($name); + } + static::$snapshot_names[] = $name; + } + + public static function restore(string $name): void { + if (!C\contains_key(static::$snapshot_names, $name)) { + throw new SQLFakeRuntimeException("Snapshot $name not found, unable to restore"); + } + foreach (static::getAll() as $server) { + $server->doRestore($name); + } + } + + protected function doSnapshot(string $name): void { + $this->snapshots[$name] = $this->databases; + } + + protected function doRestore(string $name): void { + $this->databases = $this->snapshots[$name] ?? 
dict[]; + } + + protected function doReset(): void { + $this->databases = dict[]; + } + + /** + * The main storage mechanism + * dict of strings (database schema names) + * -> dict of string table names to tables + * -> vec of rows + * -> dict of string column names to columns + * + * While a structure based on objects all the way down the stack may be more powerful and readable, + * This structure uses value types intentionally, to enable a relatively efficient reset/snapshot logic + * which is often used frequently between test cases + */ + public dict>>> $databases = dict[]; + private dict>>>> $snapshots = dict[]; + + /** + * Retrieve a table from the specified database, if it exists, by value + */ + public function getTable(string $dbname, string $name): ?vec> { + return $this->databases[$dbname][$name] ?? null; + } + + /** + * Save a table's rows back to the database + * note, because insert and update operations already grab the full table for checking constraints, + * we don't bother providing an insert or update helper here. + */ + public function saveTable(string $dbname, string $name, vec> $rows): void { + // create table if not exists + if (!C\contains_key($this->databases, $dbname)) { + $this->databases[$dbname] = dict[]; + } + + // save rows + $this->databases[$dbname][$name] = $rows; + } } diff --git a/src/Types.php b/src/Types.php index bf3916b..645e06b 100644 --- a/src/Types.php +++ b/src/Types.php @@ -18,75 +18,75 @@ // type token = shape( - 'type' => TokenType, - 'value' => string, - // the raw token including capitalization, quoting, and whitespace. used for generating SELECT column names for expressions - 'raw' => string, + 'type' => TokenType, + 'value' => string, + // the raw token including capitalization, quoting, and whitespace. 
used for generating SELECT column names for expressions + 'raw' => string, ); enum TokenType: string { - NUMERIC_CONSTANT = 'Number'; - STRING_CONSTANT = 'String'; - CLAUSE = 'Clause'; - OPERATOR = 'Operator'; - RESERVED = 'Reserved'; - PAREN = 'Paren'; - SEPARATOR = 'Separator'; - SQLFUNCTION = 'Function'; - IDENTIFIER = 'Identifier'; - NULL_CONSTANT = 'Null'; - BOOLEAN_CONSTANT = 'Boolean'; + NUMERIC_CONSTANT = 'Number'; + STRING_CONSTANT = 'String'; + CLAUSE = 'Clause'; + OPERATOR = 'Operator'; + RESERVED = 'Reserved'; + PAREN = 'Paren'; + SEPARATOR = 'Separator'; + SQLFUNCTION = 'Function'; + IDENTIFIER = 'Identifier'; + NULL_CONSTANT = 'Null'; + BOOLEAN_CONSTANT = 'Boolean'; } enum JoinType: string { - JOIN = 'JOIN'; - LEFT = 'LEFT'; - RIGHT = 'RIGHT'; - CROSS = 'CROSS'; - STRAIGHT = 'STRAIGHT_JOIN'; - NATURAL = 'NATURAL'; + JOIN = 'JOIN'; + LEFT = 'LEFT'; + RIGHT = 'RIGHT'; + CROSS = 'CROSS'; + STRAIGHT = 'STRAIGHT_JOIN'; + NATURAL = 'NATURAL'; } enum Verbosity: int as int { - // Default, print nothing - QUIET = 1; - // Print every query as it executes - QUERIES = 2; - // Print every query and its results - RESULTS = 3; + // Default, print nothing + QUIET = 1; + // Print every query as it executes + QUERIES = 2; + // Print every query and its results + RESULTS = 3; } enum JoinOperator: string { - ON = 'ON'; - USING = 'USING'; + ON = 'ON'; + USING = 'USING'; } enum SortDirection: string { - ASC = 'ASC'; - DESC = 'DESC'; + ASC = 'ASC'; + DESC = 'DESC'; } enum MultiOperand: string { - UNION = 'UNION'; - UNION_ALL = 'UNION_ALL'; - EXCEPT = 'EXCEPT'; - INTERSECT = 'INTERSECT'; + UNION = 'UNION'; + UNION_ALL = 'UNION_ALL'; + EXCEPT = 'EXCEPT'; + INTERSECT = 'INTERSECT'; } type token_list = vec; type from_table = shape( - 'name' => string, - ?'subquery' => SubqueryExpression, - 'join_type' => JoinType, - ?'join_operator' => JoinOperator, - ?'alias' => string, - ?'join_expression' => ?Expression, + 'name' => string, + ?'subquery' => SubqueryExpression, + 'join_type' 
=> JoinType, + ?'join_operator' => JoinOperator, + ?'alias' => string, + ?'join_expression' => ?Expression, ); type limit_clause = shape( - 'rowcount' => int, - 'offset' => int, + 'rowcount' => int, + 'offset' => int, ); type order_by_clause = vec Expression, 'direction' => SortDirection)>; @@ -98,120 +98,120 @@ enum MultiOperand: string { */ type table_schema = shape( - /** - * Table name as it exists in the database - */ - 'name' => string, - 'fields' => Container< - shape( - 'name' => string, - 'type' => DataType, - 'length' => int, - 'null' => bool, - 'hack_type' => string, - ?'unsigned' => bool, - ?'default' => string, - ), - >, - 'indexes' => Container< - shape( - 'name' => string, - 'type' => string, - 'fields' => Container, - ), - >, - ?'vitess_sharding' => shape( - 'keyspace' => string, - 'sharding_key' => string, - ), + /** + * Table name as it exists in the database + */ + 'name' => string, + 'fields' => Container< + shape( + 'name' => string, + 'type' => DataType, + 'length' => int, + 'null' => bool, + 'hack_type' => string, + ?'unsigned' => bool, + ?'default' => string, + ), + >, + 'indexes' => Container< + shape( + 'name' => string, + 'type' => string, + 'fields' => Container, + ), + >, + ?'vitess_sharding' => shape( + 'keyspace' => string, + 'sharding_key' => string, + ), ); enum DataType: string { - TINYINT = 'TINYINT'; - SMALLINT = 'SMALLINT'; - MEDIUMINT = 'MEDIUMINT'; - INT = 'INT'; - BIT = 'BIT'; - BIGINT = 'BIGINT'; - FLOAT = 'FLOAT'; - DOUBLE = 'DOUBLE'; - BINARY = 'BINARY'; - CHAR = 'CHAR'; - ENUM = 'ENUM'; - SET = 'SET'; - TINYBLOB = 'TINYBLOB'; - BLOB = 'BLOB'; - MEDIUMBLOB = 'MEDIUMBLOB'; - LONGBLOB = 'LONGBLOB'; - TEXT = 'TEXT'; - TINYTEXT = 'TINYTEXT'; - MEDIUMTEXT = 'MEDIUMTEXT'; - LONGTEXT = 'LONGTEXT'; - VARCHAR = 'VARCHAR'; - VARBINARY = 'VARBINARY'; - JSON = 'JSON'; - DATE = 'DATE'; - DATETIME = 'DATETIME'; - TIME = 'TIME'; - YEAR = 'YEAR'; - TIMESTAMP = 'TIMESTAMP'; - DECIMAL = 'DECIMAL'; - NUMERIC = 'NUMERIC'; + TINYINT = 
'TINYINT'; + SMALLINT = 'SMALLINT'; + MEDIUMINT = 'MEDIUMINT'; + INT = 'INT'; + BIT = 'BIT'; + BIGINT = 'BIGINT'; + FLOAT = 'FLOAT'; + DOUBLE = 'DOUBLE'; + BINARY = 'BINARY'; + CHAR = 'CHAR'; + ENUM = 'ENUM'; + SET = 'SET'; + TINYBLOB = 'TINYBLOB'; + BLOB = 'BLOB'; + MEDIUMBLOB = 'MEDIUMBLOB'; + LONGBLOB = 'LONGBLOB'; + TEXT = 'TEXT'; + TINYTEXT = 'TINYTEXT'; + MEDIUMTEXT = 'MEDIUMTEXT'; + LONGTEXT = 'LONGTEXT'; + VARCHAR = 'VARCHAR'; + VARBINARY = 'VARBINARY'; + JSON = 'JSON'; + DATE = 'DATE'; + DATETIME = 'DATETIME'; + TIME = 'TIME'; + YEAR = 'YEAR'; + TIMESTAMP = 'TIMESTAMP'; + DECIMAL = 'DECIMAL'; + NUMERIC = 'NUMERIC'; } enum Operator: string { - AMPERSAND = '&'; - AND = 'AND'; - ANY = 'ANY'; - ASTERISK = '*'; - BANG = '!'; - BANG_EQUALS = '!='; - BETWEEN = 'BETWEEN'; - BINARY = 'BINARY'; - CASE = 'CASE'; - CARET = '^'; - COLLATE = 'COLLATE'; - DIV = 'DIV'; - DOUBLE_AMPERSAND = '&&'; - DOUBLE_GREATER_THAN = '>>'; - DOUBLE_LESS_THAN = '<<'; - DOUBLE_PIPE = '||'; - ELSE = 'ELSE'; - END = 'END'; - EQUALS = '='; - EXISTS = 'EXISTS'; - FORWARD_SLASH = '/'; - GREATER_THAN = '>'; - GREATER_THAN_EQUALS = '>='; - LESS_THAN = '<'; - LESS_THAN_EQUALS = '<='; - LESS_THAN_EQUALS_GREATER_THAN = '<=>'; - LESS_THAN_GREATER_THAN = '<>'; - LIKE = 'LIKE'; - IN = 'IN'; - INTERVAL = 'INTERVAL'; - IS = 'IS'; - MOD = 'MOD'; - MINUS = '-'; - NOT = 'NOT'; - OR = 'OR'; - PERCENT = '%'; - PIPE = '|'; - PLUS = '+'; - RLIKE = 'RLIKE'; - REGEXP = 'REGEXP'; - SOME = 'SOME'; - SOUNDS = 'SOUNDS'; - THEN = 'THEN'; - TILDE = '~'; - WHEN = 'WHEN'; - UNARY_MINUS = 'UNARY_MINUS'; - UNARY_PLUS = 'UNARY_PLUS'; - XOR = 'XOR'; + AMPERSAND = '&'; + AND = 'AND'; + ANY = 'ANY'; + ASTERISK = '*'; + BANG = '!'; + BANG_EQUALS = '!='; + BETWEEN = 'BETWEEN'; + BINARY = 'BINARY'; + CASE = 'CASE'; + CARET = '^'; + COLLATE = 'COLLATE'; + DIV = 'DIV'; + DOUBLE_AMPERSAND = '&&'; + DOUBLE_GREATER_THAN = '>>'; + DOUBLE_LESS_THAN = '<<'; + DOUBLE_PIPE = '||'; + ELSE = 'ELSE'; + END = 'END'; + EQUALS = '='; + EXISTS = 
'EXISTS'; + FORWARD_SLASH = '/'; + GREATER_THAN = '>'; + GREATER_THAN_EQUALS = '>='; + LESS_THAN = '<'; + LESS_THAN_EQUALS = '<='; + LESS_THAN_EQUALS_GREATER_THAN = '<=>'; + LESS_THAN_GREATER_THAN = '<>'; + LIKE = 'LIKE'; + IN = 'IN'; + INTERVAL = 'INTERVAL'; + IS = 'IS'; + MOD = 'MOD'; + MINUS = '-'; + NOT = 'NOT'; + OR = 'OR'; + PERCENT = '%'; + PIPE = '|'; + PLUS = '+'; + RLIKE = 'RLIKE'; + REGEXP = 'REGEXP'; + SOME = 'SOME'; + SOUNDS = 'SOUNDS'; + THEN = 'THEN'; + TILDE = '~'; + WHEN = 'WHEN'; + UNARY_MINUS = 'UNARY_MINUS'; + UNARY_PLUS = 'UNARY_PLUS'; + XOR = 'XOR'; } function operator_to_string(Operator $o): string { - return $o as string; + return $o as string; } /** @@ -220,28 +220,28 @@ function operator_to_string(Operator $o): string { * If Operators flow into string typed code, this conversion needs to happen again. */ function operatorn_to_string(?Operator $o): string { - return $o as ?string ?? ''; + return $o as ?string ?? ''; } type server_config = shape( - // i.e. 5.6, 5.7 - 'mysql_version' => string, - ?'is_vitess' => bool, - 'strict_sql_mode' => bool, - 'strict_schema_mode' => bool, - // name of a database in table configuration to copy schema from - ?'inherit_schema_from' => string, + // i.e. 
5.6, 5.7 + 'mysql_version' => string, + ?'is_vitess' => bool, + 'strict_sql_mode' => bool, + 'strict_schema_mode' => bool, + // name of a database in table configuration to copy schema from + ?'inherit_schema_from' => string, ); // Wrapped values to retain original type information class WrappedJSON { - public function __construct(private mixed $json) {} + public function __construct(private mixed $json) {} - public function asString(): string { - return \json_encode($this->json); - } + public function asString(): string { + return \json_encode($this->json); + } - public function rawValue(): mixed { - return $this->json; - } + public function rawValue(): mixed { + return $this->json; + } } diff --git a/src/VitessQueryValidator.php b/src/VitessQueryValidator.php index 4a91bd0..e8e69c6 100644 --- a/src/VitessQueryValidator.php +++ b/src/VitessQueryValidator.php @@ -75,8 +75,7 @@ public static function enablePrimaryVindexColumnValidator(): void { public function getHandlers(): dict)> { $handlers = dict[]; if (static::$isPrimaryVindexColumnValidatorEnabled) { - $handlers[UnsupportedCases::PRIMARY_VINDEX_COLUMN] = async () ==> - await $this->updateChangesPrimaryVindexColumn(); + $handlers[UnsupportedCases::PRIMARY_VINDEX_COLUMN] = async () ==> await $this->updateChangesPrimaryVindexColumn(); } return $handlers; @@ -205,9 +204,8 @@ private function isCrossShardQuery(): bool { case MultiOperand::UNION_ALL: case MultiOperand::INTERSECT: case MultiOperand::EXCEPT: - throw new SQLFakeVitessQueryViolation( - Str\format('Vitess query validation error: %s', UnsupportedCases::UNIONS), - ); + throw + new SQLFakeVitessQueryViolation(Str\format('Vitess query validation error: %s', UnsupportedCases::UNIONS)); } } } diff --git a/tests/InsertQueryTest.php b/tests/InsertQueryTest.php index 60880b0..2782131 100644 --- a/tests/InsertQueryTest.php +++ b/tests/InsertQueryTest.php @@ -7,393 +7,393 @@ final class InsertQueryTest extends HackTest { - private static ?AsyncMysqlConnection 
$conn; - - <<__Override>> - public static async function beforeFirstTestAsync(): Awaitable { - init(TEST_SCHEMA, true); - $pool = new AsyncMysqlConnectionPool(darray[]); - static::$conn = await $pool->connect('example', 1, 'db1', '', ''); - // black hole logging - Logger::setHandle(new \HH\Lib\IO\MemoryHandle()); - } - - <<__Override>> - public async function beforeEachTestAsync(): Awaitable { - Server::reset(); - QueryContext::$strictSQLMode = false; - QueryContext::$strictSchemaMode = false; - } - - public async function testSingleInsert(): Awaitable { - $conn = static::$conn as nonnull; - await $conn->query("INSERT INTO table1 (id, name) VALUES (1, 'test')"); - $results = await $conn->query('SELECT * FROM table1'); - expect($results->rows())->toBeSame(vec[dict['id' => 1, 'name' => 'test']]); - } - - public async function testSingleInsertBacktickIdentifiers(): Awaitable { - $conn = static::$conn as nonnull; - await $conn->query("INSERT INTO `table1` (`id`, `name`) VALUES (1, 'test')"); - $results = await $conn->query('SELECT * FROM `table1`'); - expect($results->rows())->toBeSame(vec[dict['id' => 1, 'name' => 'test']]); - } - - public async function testMultiInsert(): Awaitable { - $conn = static::$conn as nonnull; - await $conn->query("INSERT INTO table1 (id, name) VALUES (1, 'test'), (2, 'test2')"); - $results = await $conn->query('SELECT * FROM table1'); - expect($results->rows())->toBeSame(vec[dict['id' => 1, 'name' => 'test'], dict['id' => 2, 'name' => 'test2']]); - } - - public async function testPKViolation(): Awaitable { - $conn = static::$conn as nonnull; - await $conn->query("INSERT INTO table1 (id, name) VALUES (1, 'test')"); - expect(() ==> $conn->query("INSERT INTO table1 (id, name) VALUES (1, 'test2')"))->toThrow( - SQLFakeUniqueKeyViolation::class, - "Duplicate entry '1' for key 'PRIMARY' in table 'table1'", - ); - } - - public async function testPKViolationInsertIgnore(): Awaitable { - $conn = static::$conn as nonnull; - await $conn->query("INSERT 
INTO table1 (id, name) VALUES (1, 'test')"); - expect(() ==> $conn->query("INSERT IGNORE INTO table1 (id, name) VALUES (1, 'test2')"))->notToThrow( - SQLFakeUniqueKeyViolation::class, - ); - $results = await $conn->query('SELECT * FROM table1'); - expect($results->rows())->toBeSame( - vec[dict[ - 'id' => 1, - 'name' => 'test', - ]], - ); - } - - public async function testPKViolationWithinMultiInsert(): Awaitable { - $conn = static::$conn as nonnull; - expect(() ==> $conn->query("INSERT INTO table1 (id, name) VALUES (1, 'test'), (1, 'test2')"))->toThrow( - SQLFakeUniqueKeyViolation::class, - "Duplicate entry '1' for key 'PRIMARY' in table 'table1'", - ); - } - - public async function testUniqueViolation(): Awaitable { - $conn = static::$conn as nonnull; - await $conn->query("INSERT INTO table1 (id, name) VALUES (1, 'test')"); - expect(() ==> $conn->query("INSERT INTO table1 (id, name) VALUES (2, 'test')"))->toThrow( - SQLFakeUniqueKeyViolation::class, - "Duplicate entry 'test' for key 'name_uniq' in table 'table1'", - ); - } - - public async function testUniqueViolationInsertIgnore(): Awaitable { - $conn = static::$conn as nonnull; - await $conn->query("INSERT INTO table1 (id, name) VALUES (1, 'test')"); - expect(() ==> $conn->query("INSERT IGNORE INTO table1 (id, name) VALUES (2, 'test')"))->notToThrow( - SQLFakeUniqueKeyViolation::class, - ); - $results = await $conn->query('SELECT * FROM table1'); - expect($results->rows())->toBeSame( - vec[dict[ - 'id' => 1, - 'name' => 'test', - ]], - ); - } - - public async function testPartialValuesList(): Awaitable { - $conn = static::$conn as nonnull; - await $conn->query("INSERT INTO table_with_more_fields (id, name) VALUES (1, 'test')"); - $results = await $conn->query('SELECT * FROM table_with_more_fields'); - expect($results->rows())->toBeSame(vec[dict[ - 'id' => 1, - 'name' => 'test', - 'nullable_unique' => null, - 'nullable_default' => 1, - 'not_null_default' => 2, - ]]); - } - - public async function 
testExplicitNullForNullableField(): Awaitable { - $conn = static::$conn as nonnull; - await $conn->query("INSERT INTO table_with_more_fields (id, name, nullable_unique) VALUES (1, 'test', null)"); - $results = await $conn->query('SELECT * FROM table_with_more_fields'); - expect($results->rows())->toBeSame(vec[dict[ - 'id' => 1, - 'name' => 'test', - 'nullable_unique' => null, - 'nullable_default' => 1, - 'not_null_default' => 2, - ]]); - } - - public async function testExplicitNullForNotNullableFieldStrict(): Awaitable { - $conn = static::$conn as nonnull; - QueryContext::$strictSQLMode = true; - expect( - () ==> $conn->query("INSERT INTO table_with_more_fields (id, name, not_null_default) VALUES (1, 'test', null)"), - )->toThrow( - SQLFakeRuntimeException::class, - "Column 'not_null_default' on 'table_with_more_fields' does not allow null values", - ); - } - - public async function testExplicitNullForNotNullableFieldNotStrict(): Awaitable { - $conn = static::$conn as nonnull; - await $conn->query("INSERT INTO table_with_more_fields (id, name, not_null_default) VALUES (1, 'test', null)"); - $results = await $conn->query('SELECT * FROM table_with_more_fields'); - expect($results->rows())->toBeSame(vec[dict[ - 'id' => 1, - 'name' => 'test', - 'nullable_unique' => null, - 'nullable_default' => 1, - 'not_null_default' => 2, - ]]); - } - - public async function testCompoundPKNoViolation(): Awaitable { - $conn = static::$conn as nonnull; - await $conn->query("INSERT INTO table_with_more_fields (id, name) VALUES (1, 'test')"); - expect(() ==> $conn->query("INSERT INTO table_with_more_fields (id, name) VALUES (1, 'test2')"))->notToThrow( - SQLFakeUniqueKeyViolation::class, - ); - } - - public async function testCompoundPKViolation(): Awaitable { - $conn = static::$conn as nonnull; - await $conn->query("INSERT INTO table_with_more_fields (id, name) VALUES (1, 'test')"); - expect(() ==> $conn->query("INSERT INTO table_with_more_fields (id, name) VALUES (1, 
'test')"))->toThrow( - SQLFakeUniqueKeyViolation::class, - "Duplicate entry '1, test' for key 'PRIMARY' in table 'table_with_more_fields'", - ); - } - - public async function testMismatchedValuesList(): Awaitable { - $conn = static::$conn as nonnull; - expect(() ==> $conn->query("INSERT INTO table1 (id, name, col3) VALUES (1, 'test2')"))->toThrow( - SQLFakeParseException::class, - 'Insert list contains 3 fields, but values clause contains 2', - ); - } - - public async function testNullableUniqueNoViolation(): Awaitable { - $conn = static::$conn as nonnull; - await $conn->query("INSERT INTO table_with_more_fields (id, name, nullable_unique) VALUES (1, 'test', 'example')"); - expect( - () ==> $conn->query("INSERT INTO table_with_more_fields (id, name, nullable_unique) VALUES (2, 'test2', null)"), - ) - ->notToThrow(SQLFakeUniqueKeyViolation::class); - } - - public async function testNullableUniqueViolation(): Awaitable { - $conn = static::$conn as nonnull; - await $conn->query("INSERT INTO table_with_more_fields (id, name, nullable_unique) VALUES (1, 'test', 'example')"); - expect( - () ==> - $conn->query("INSERT INTO table_with_more_fields (id, name, nullable_unique) VALUES (2, 'test2', 'example')"), - ) - ->toThrow( - SQLFakeUniqueKeyViolation::class, - "Duplicate entry 'example' for key 'nullable_unique' in table 'table_with_more_fields'", - ); - } - - public async function testMissingNotNullFieldNoDefault(): Awaitable { - $conn = static::$conn as nonnull; - await $conn->query('INSERT INTO table2 (id, table_1_id) VALUES (1, 1)'); - $results = await $conn->query('SELECT * FROM table2'); - expect($results->rows())->toBeSame(vec[dict[ - 'id' => 1, - 'table_1_id' => 1, - 'description' => '', - ]]); - } - - public async function testMissingNotNullFieldNoDefaultStrict(): Awaitable { - $conn = static::$conn as nonnull; - QueryContext::$strictSQLMode = true; - expect(() ==> $conn->query('INSERT INTO table2 (id, table_1_id) VALUES (1, 1)'))->toThrow( - 
SQLFakeRuntimeException::class, - "Column 'description' on 'table2' does not allow null values", - ); - } - - public async function testWrongDataType(): Awaitable { - $conn = static::$conn as nonnull; - await $conn->query("INSERT INTO table2 (id, table_1_id, description) VALUES (1, 'notastring', 'test')"); - $results = await $conn->query('SELECT * FROM table2'); - expect($results->rows())->toBeSame(vec[dict[ - 'id' => 1, - 'table_1_id' => 0, - 'description' => 'test', - ]]); - } - - public async function testWrongDataTypeStrict(): Awaitable { - $conn = static::$conn as nonnull; - QueryContext::$strictSQLMode = true; - expect(() ==> $conn->query("INSERT INTO table2 (id, table_1_id, description) VALUES (1, 'notastring', 'test')")) - ->toThrow( - SQLFakeRuntimeException::class, - "Invalid value 'notastring' for column 'table_1_id' on 'table2', expected int", - ); - } - - public async function testEmptyStringInsertIntoJsonColumn(): Awaitable { - $conn = static::$conn as nonnull; - QueryContext::$strictSQLMode = true; - expect(() ==> $conn->query("INSERT INTO table_with_json (id, data) VALUES (1, '')")) - ->toThrow( - SQLFakeRuntimeException::class, - "Invalid value '' for column 'data' on 'table_with_json', expected json", - ); - } - - public async function testInvalidJsonStringInsertIntoJsonColumn(): Awaitable { - $conn = static::$conn as nonnull; - QueryContext::$strictSQLMode = true; - expect(() ==> $conn->query("INSERT INTO table_with_json (id, data) VALUES (1, 'abc')")) - ->toThrow( - SQLFakeRuntimeException::class, - "Invalid value 'abc' for column 'data' on 'table_with_json', expected json", - ); - } - - public async function testNullStringCapsInsertIntoJsonColumn(): Awaitable { - $conn = static::$conn as nonnull; - QueryContext::$strictSQLMode = true; - expect(() ==> $conn->query("INSERT INTO table_with_json (id, data) VALUES (1, 'NULL')")) - ->toThrow( - SQLFakeRuntimeException::class, - "Invalid value 'NULL' for column 'data' on 'table_with_json', expected 
json", - ); - } - - public async function testNullStringLowercaseInsertIntoJsonColumn(): Awaitable { - $conn = static::$conn as nonnull; - QueryContext::$strictSQLMode = true; - await $conn->query("INSERT INTO table_with_json (id, data) VALUES (1, 'null')"); - $result = await $conn->query('SELECT * FROM table_with_json'); - expect($result->rows())->toBeSame(vec[dict['id' => 1, 'data' => null]]); - } - - public async function testNullInsertIntoJsonColumn(): Awaitable { - $conn = static::$conn as nonnull; - QueryContext::$strictSQLMode = true; - await $conn->query('INSERT INTO table_with_json (id, data) VALUES (1, NULL)'); - $result = await $conn->query('SELECT * FROM table_with_json'); - expect($result->rows())->toBeSame(vec[dict['id' => 1, 'data' => null]]); - } - - public async function testValidJsonInsertIntoJsonColumn(): Awaitable { - $conn = static::$conn as nonnull; - QueryContext::$strictSQLMode = true; - await $conn->query("INSERT INTO table_with_json (id, data) VALUES (1, '{\"test\":123}')"); - $result = await $conn->query('SELECT * FROM table_with_json'); - expect($result->rows())->toBeSame(vec[dict['id' => 1, 'data' => '{"test":123}']]); - } - - public async function testDupeInsertNoConflicts(): Awaitable { - $conn = static::$conn as nonnull; - await $conn->query( - "INSERT INTO table_with_more_fields (id, name, nullable_unique) VALUES (1, 'test', 'example') ON DUPLICATE KEY UPDATE nullable_default=nullable_default+1", - ); - $results = await $conn->query('SELECT * FROM table_with_more_fields'); - expect($results->rows())->toBeSame(vec[dict[ - 'id' => 1, - 'name' => 'test', - 'nullable_unique' => 'example', - 'nullable_default' => 1, - 'not_null_default' => 2, - ]]); - } - - public async function testDupeInsertWithConflicts(): Awaitable { - $conn = static::$conn as nonnull; - await $conn->query( - "INSERT INTO table_with_more_fields (id, name, nullable_unique) VALUES (1, 'test', 'example') ON DUPLICATE KEY UPDATE nullable_default=nullable_default+1", - ); 
- await $conn->query( - "INSERT INTO table_with_more_fields (id, name, nullable_unique) VALUES (1, 'test', 'example') ON DUPLICATE KEY UPDATE nullable_default=nullable_default+1", - ); - $results = await $conn->query('SELECT * FROM table_with_more_fields'); - expect($results->rows())->toBeSame(vec[dict[ - 'id' => 1, - 'name' => 'test', - 'nullable_unique' => 'example', - 'nullable_default' => 2, - 'not_null_default' => 2, - ]]); - } - - public async function testDupeInsertWithValuesFunction(): Awaitable { - $conn = static::$conn as nonnull; - await $conn->query("INSERT INTO table_with_more_fields (id, name, nullable_unique) VALUES (1, 'test', 'example')"); - await $conn->query( - "INSERT INTO table_with_more_fields (id, name, nullable_unique) VALUES (1, 'test', 'new_example') ON DUPLICATE KEY UPDATE nullable_unique=VALUES(nullable_unique)", - ); - $results = await $conn->query('SELECT * FROM table_with_more_fields'); - expect($results->rows())->toBeSame(vec[dict[ - 'id' => 1, - 'name' => 'test', - 'nullable_unique' => 'new_example', - 'nullable_default' => 1, - 'not_null_default' => 2, - ]]); - } - - public async function testParseComplexWithEscapedJSONAndComment(): Awaitable { - $conn = static::$conn as nonnull; - await $conn->query( - "INSERT INTO table_with_more_fields (`id`, `name`, `nullable_unique`) VALUES (56789,'{\\\"tes\\'st\\\":\\\"12345\\\"}','test') /* SQL Comment */", - ); - $results = await $conn->query('SELECT * FROM table_with_more_fields'); - expect($results->rows())->toBeSame(vec[ - dict[ - 'id' => 56789, - 'name' => '{"tes\'st":"12345"}', - 'nullable_unique' => 'test', - 'nullable_default' => 1, - 'not_null_default' => 2, - ], - ]); - } - - public async function testDupeInsertEscaping(): Awaitable { - $conn = static::$conn as nonnull; - await $conn->query(<<<'EOT' - INSERT INTO table1 (`id`,`name`) VALUES (123456789, 'xÚdfíá()ÊÏMÊÏKòáÂÕÿfl©99ùåp>sQj¤Ø©¸¨©=)7±(I{^PSj\\%Krbv*+#©¶ Ì\0Ma\0a\0¤Ý7\\') - ON DUPLICATE KEY UPDATE 
`name`='xÚdfíá()ÊÏMÊÏKòáÂÕÿfl©99ùåp>sQj¤Ø©¸¨©=)7±(I{^PSj\\%Krbv*+#©¶ Ì\0Ma\0a\0¤Ý7\\' + private static ?AsyncMysqlConnection $conn; + + <<__Override>> + public static async function beforeFirstTestAsync(): Awaitable { + init(TEST_SCHEMA, true); + $pool = new AsyncMysqlConnectionPool(darray[]); + static::$conn = await $pool->connect('example', 1, 'db1', '', ''); + // black hole logging + Logger::setHandle(new \HH\Lib\IO\MemoryHandle()); + } + + <<__Override>> + public async function beforeEachTestAsync(): Awaitable { + Server::reset(); + QueryContext::$strictSQLMode = false; + QueryContext::$strictSchemaMode = false; + } + + public async function testSingleInsert(): Awaitable { + $conn = static::$conn as nonnull; + await $conn->query("INSERT INTO table1 (id, name) VALUES (1, 'test')"); + $results = await $conn->query('SELECT * FROM table1'); + expect($results->rows())->toBeSame(vec[dict['id' => 1, 'name' => 'test']]); + } + + public async function testSingleInsertBacktickIdentifiers(): Awaitable { + $conn = static::$conn as nonnull; + await $conn->query("INSERT INTO `table1` (`id`, `name`) VALUES (1, 'test')"); + $results = await $conn->query('SELECT * FROM `table1`'); + expect($results->rows())->toBeSame(vec[dict['id' => 1, 'name' => 'test']]); + } + + public async function testMultiInsert(): Awaitable { + $conn = static::$conn as nonnull; + await $conn->query("INSERT INTO table1 (id, name) VALUES (1, 'test'), (2, 'test2')"); + $results = await $conn->query('SELECT * FROM table1'); + expect($results->rows())->toBeSame(vec[dict['id' => 1, 'name' => 'test'], dict['id' => 2, 'name' => 'test2']]); + } + + public async function testPKViolation(): Awaitable { + $conn = static::$conn as nonnull; + await $conn->query("INSERT INTO table1 (id, name) VALUES (1, 'test')"); + expect(() ==> $conn->query("INSERT INTO table1 (id, name) VALUES (1, 'test2')"))->toThrow( + SQLFakeUniqueKeyViolation::class, + "Duplicate entry '1' for key 'PRIMARY' in table 'table1'", + ); + } + + 
public async function testPKViolationInsertIgnore(): Awaitable { + $conn = static::$conn as nonnull; + await $conn->query("INSERT INTO table1 (id, name) VALUES (1, 'test')"); + expect(() ==> $conn->query("INSERT IGNORE INTO table1 (id, name) VALUES (1, 'test2')"))->notToThrow( + SQLFakeUniqueKeyViolation::class, + ); + $results = await $conn->query('SELECT * FROM table1'); + expect($results->rows())->toBeSame( + vec[dict[ + 'id' => 1, + 'name' => 'test', + ]], + ); + } + + public async function testPKViolationWithinMultiInsert(): Awaitable { + $conn = static::$conn as nonnull; + expect(() ==> $conn->query("INSERT INTO table1 (id, name) VALUES (1, 'test'), (1, 'test2')"))->toThrow( + SQLFakeUniqueKeyViolation::class, + "Duplicate entry '1' for key 'PRIMARY' in table 'table1'", + ); + } + + public async function testUniqueViolation(): Awaitable { + $conn = static::$conn as nonnull; + await $conn->query("INSERT INTO table1 (id, name) VALUES (1, 'test')"); + expect(() ==> $conn->query("INSERT INTO table1 (id, name) VALUES (2, 'test')"))->toThrow( + SQLFakeUniqueKeyViolation::class, + "Duplicate entry 'test' for key 'name_uniq' in table 'table1'", + ); + } + + public async function testUniqueViolationInsertIgnore(): Awaitable { + $conn = static::$conn as nonnull; + await $conn->query("INSERT INTO table1 (id, name) VALUES (1, 'test')"); + expect(() ==> $conn->query("INSERT IGNORE INTO table1 (id, name) VALUES (2, 'test')"))->notToThrow( + SQLFakeUniqueKeyViolation::class, + ); + $results = await $conn->query('SELECT * FROM table1'); + expect($results->rows())->toBeSame( + vec[dict[ + 'id' => 1, + 'name' => 'test', + ]], + ); + } + + public async function testPartialValuesList(): Awaitable { + $conn = static::$conn as nonnull; + await $conn->query("INSERT INTO table_with_more_fields (id, name) VALUES (1, 'test')"); + $results = await $conn->query('SELECT * FROM table_with_more_fields'); + expect($results->rows())->toBeSame(vec[dict[ + 'id' => 1, + 'name' => 'test', + 
'nullable_unique' => null, + 'nullable_default' => 1, + 'not_null_default' => 2, + ]]); + } + + public async function testExplicitNullForNullableField(): Awaitable { + $conn = static::$conn as nonnull; + await $conn->query("INSERT INTO table_with_more_fields (id, name, nullable_unique) VALUES (1, 'test', null)"); + $results = await $conn->query('SELECT * FROM table_with_more_fields'); + expect($results->rows())->toBeSame(vec[dict[ + 'id' => 1, + 'name' => 'test', + 'nullable_unique' => null, + 'nullable_default' => 1, + 'not_null_default' => 2, + ]]); + } + + public async function testExplicitNullForNotNullableFieldStrict(): Awaitable { + $conn = static::$conn as nonnull; + QueryContext::$strictSQLMode = true; + expect( + () ==> $conn->query("INSERT INTO table_with_more_fields (id, name, not_null_default) VALUES (1, 'test', null)"), + )->toThrow( + SQLFakeRuntimeException::class, + "Column 'not_null_default' on 'table_with_more_fields' does not allow null values", + ); + } + + public async function testExplicitNullForNotNullableFieldNotStrict(): Awaitable { + $conn = static::$conn as nonnull; + await $conn->query("INSERT INTO table_with_more_fields (id, name, not_null_default) VALUES (1, 'test', null)"); + $results = await $conn->query('SELECT * FROM table_with_more_fields'); + expect($results->rows())->toBeSame(vec[dict[ + 'id' => 1, + 'name' => 'test', + 'nullable_unique' => null, + 'nullable_default' => 1, + 'not_null_default' => 2, + ]]); + } + + public async function testCompoundPKNoViolation(): Awaitable { + $conn = static::$conn as nonnull; + await $conn->query("INSERT INTO table_with_more_fields (id, name) VALUES (1, 'test')"); + expect(() ==> $conn->query("INSERT INTO table_with_more_fields (id, name) VALUES (1, 'test2')"))->notToThrow( + SQLFakeUniqueKeyViolation::class, + ); + } + + public async function testCompoundPKViolation(): Awaitable { + $conn = static::$conn as nonnull; + await $conn->query("INSERT INTO table_with_more_fields (id, name) VALUES 
(1, 'test')"); + expect(() ==> $conn->query("INSERT INTO table_with_more_fields (id, name) VALUES (1, 'test')"))->toThrow( + SQLFakeUniqueKeyViolation::class, + "Duplicate entry '1, test' for key 'PRIMARY' in table 'table_with_more_fields'", + ); + } + + public async function testMismatchedValuesList(): Awaitable { + $conn = static::$conn as nonnull; + expect(() ==> $conn->query("INSERT INTO table1 (id, name, col3) VALUES (1, 'test2')"))->toThrow( + SQLFakeParseException::class, + 'Insert list contains 3 fields, but values clause contains 2', + ); + } + + public async function testNullableUniqueNoViolation(): Awaitable { + $conn = static::$conn as nonnull; + await $conn->query("INSERT INTO table_with_more_fields (id, name, nullable_unique) VALUES (1, 'test', 'example')"); + expect( + () ==> $conn->query("INSERT INTO table_with_more_fields (id, name, nullable_unique) VALUES (2, 'test2', null)"), + ) + ->notToThrow(SQLFakeUniqueKeyViolation::class); + } + + public async function testNullableUniqueViolation(): Awaitable { + $conn = static::$conn as nonnull; + await $conn->query("INSERT INTO table_with_more_fields (id, name, nullable_unique) VALUES (1, 'test', 'example')"); + expect( + () ==> + $conn->query("INSERT INTO table_with_more_fields (id, name, nullable_unique) VALUES (2, 'test2', 'example')"), + ) + ->toThrow( + SQLFakeUniqueKeyViolation::class, + "Duplicate entry 'example' for key 'nullable_unique' in table 'table_with_more_fields'", + ); + } + + public async function testMissingNotNullFieldNoDefault(): Awaitable { + $conn = static::$conn as nonnull; + await $conn->query('INSERT INTO table2 (id, table_1_id) VALUES (1, 1)'); + $results = await $conn->query('SELECT * FROM table2'); + expect($results->rows())->toBeSame(vec[dict[ + 'id' => 1, + 'table_1_id' => 1, + 'description' => '', + ]]); + } + + public async function testMissingNotNullFieldNoDefaultStrict(): Awaitable { + $conn = static::$conn as nonnull; + QueryContext::$strictSQLMode = true; + expect(() 
==> $conn->query('INSERT INTO table2 (id, table_1_id) VALUES (1, 1)'))->toThrow( + SQLFakeRuntimeException::class, + "Column 'description' on 'table2' does not allow null values", + ); + } + + public async function testWrongDataType(): Awaitable { + $conn = static::$conn as nonnull; + await $conn->query("INSERT INTO table2 (id, table_1_id, description) VALUES (1, 'notastring', 'test')"); + $results = await $conn->query('SELECT * FROM table2'); + expect($results->rows())->toBeSame(vec[dict[ + 'id' => 1, + 'table_1_id' => 0, + 'description' => 'test', + ]]); + } + + public async function testWrongDataTypeStrict(): Awaitable { + $conn = static::$conn as nonnull; + QueryContext::$strictSQLMode = true; + expect(() ==> $conn->query("INSERT INTO table2 (id, table_1_id, description) VALUES (1, 'notastring', 'test')")) + ->toThrow( + SQLFakeRuntimeException::class, + "Invalid value 'notastring' for column 'table_1_id' on 'table2', expected int", + ); + } + + public async function testEmptyStringInsertIntoJsonColumn(): Awaitable { + $conn = static::$conn as nonnull; + QueryContext::$strictSQLMode = true; + expect(() ==> $conn->query("INSERT INTO table_with_json (id, data) VALUES (1, '')")) + ->toThrow( + SQLFakeRuntimeException::class, + "Invalid value '' for column 'data' on 'table_with_json', expected json", + ); + } + + public async function testInvalidJsonStringInsertIntoJsonColumn(): Awaitable { + $conn = static::$conn as nonnull; + QueryContext::$strictSQLMode = true; + expect(() ==> $conn->query("INSERT INTO table_with_json (id, data) VALUES (1, 'abc')")) + ->toThrow( + SQLFakeRuntimeException::class, + "Invalid value 'abc' for column 'data' on 'table_with_json', expected json", + ); + } + + public async function testNullStringCapsInsertIntoJsonColumn(): Awaitable { + $conn = static::$conn as nonnull; + QueryContext::$strictSQLMode = true; + expect(() ==> $conn->query("INSERT INTO table_with_json (id, data) VALUES (1, 'NULL')")) + ->toThrow( + 
SQLFakeRuntimeException::class, + "Invalid value 'NULL' for column 'data' on 'table_with_json', expected json", + ); + } + + public async function testNullStringLowercaseInsertIntoJsonColumn(): Awaitable { + $conn = static::$conn as nonnull; + QueryContext::$strictSQLMode = true; + await $conn->query("INSERT INTO table_with_json (id, data) VALUES (1, 'null')"); + $result = await $conn->query('SELECT * FROM table_with_json'); + expect($result->rows())->toBeSame(vec[dict['id' => 1, 'data' => null]]); + } + + public async function testNullInsertIntoJsonColumn(): Awaitable { + $conn = static::$conn as nonnull; + QueryContext::$strictSQLMode = true; + await $conn->query('INSERT INTO table_with_json (id, data) VALUES (1, NULL)'); + $result = await $conn->query('SELECT * FROM table_with_json'); + expect($result->rows())->toBeSame(vec[dict['id' => 1, 'data' => null]]); + } + + public async function testValidJsonInsertIntoJsonColumn(): Awaitable { + $conn = static::$conn as nonnull; + QueryContext::$strictSQLMode = true; + await $conn->query("INSERT INTO table_with_json (id, data) VALUES (1, '{\"test\":123}')"); + $result = await $conn->query('SELECT * FROM table_with_json'); + expect($result->rows())->toBeSame(vec[dict['id' => 1, 'data' => '{"test":123}']]); + } + + public async function testDupeInsertNoConflicts(): Awaitable { + $conn = static::$conn as nonnull; + await $conn->query( + "INSERT INTO table_with_more_fields (id, name, nullable_unique) VALUES (1, 'test', 'example') ON DUPLICATE KEY UPDATE nullable_default=nullable_default+1", + ); + $results = await $conn->query('SELECT * FROM table_with_more_fields'); + expect($results->rows())->toBeSame(vec[dict[ + 'id' => 1, + 'name' => 'test', + 'nullable_unique' => 'example', + 'nullable_default' => 1, + 'not_null_default' => 2, + ]]); + } + + public async function testDupeInsertWithConflicts(): Awaitable { + $conn = static::$conn as nonnull; + await $conn->query( + "INSERT INTO table_with_more_fields (id, name, 
nullable_unique) VALUES (1, 'test', 'example') ON DUPLICATE KEY UPDATE nullable_default=nullable_default+1", + ); + await $conn->query( + "INSERT INTO table_with_more_fields (id, name, nullable_unique) VALUES (1, 'test', 'example') ON DUPLICATE KEY UPDATE nullable_default=nullable_default+1", + ); + $results = await $conn->query('SELECT * FROM table_with_more_fields'); + expect($results->rows())->toBeSame(vec[dict[ + 'id' => 1, + 'name' => 'test', + 'nullable_unique' => 'example', + 'nullable_default' => 2, + 'not_null_default' => 2, + ]]); + } + + public async function testDupeInsertWithValuesFunction(): Awaitable { + $conn = static::$conn as nonnull; + await $conn->query("INSERT INTO table_with_more_fields (id, name, nullable_unique) VALUES (1, 'test', 'example')"); + await $conn->query( + "INSERT INTO table_with_more_fields (id, name, nullable_unique) VALUES (1, 'test', 'new_example') ON DUPLICATE KEY UPDATE nullable_unique=VALUES(nullable_unique)", + ); + $results = await $conn->query('SELECT * FROM table_with_more_fields'); + expect($results->rows())->toBeSame(vec[dict[ + 'id' => 1, + 'name' => 'test', + 'nullable_unique' => 'new_example', + 'nullable_default' => 1, + 'not_null_default' => 2, + ]]); + } + + public async function testParseComplexWithEscapedJSONAndComment(): Awaitable { + $conn = static::$conn as nonnull; + await $conn->query( + "INSERT INTO table_with_more_fields (`id`, `name`, `nullable_unique`) VALUES (56789,'{\\\"tes\\'st\\\":\\\"12345\\\"}','test') /* SQL Comment */", + ); + $results = await $conn->query('SELECT * FROM table_with_more_fields'); + expect($results->rows())->toBeSame(vec[ + dict[ + 'id' => 56789, + 'name' => '{"tes\'st":"12345"}', + 'nullable_unique' => 'test', + 'nullable_default' => 1, + 'not_null_default' => 2, + ], + ]); + } + + public async function testDupeInsertEscaping(): Awaitable { + $conn = static::$conn as nonnull; + await $conn->query(<<<'EOT' + INSERT INTO table1 (`id`,`name`) VALUES (123456789, 
'xÚdfíá()ÊÏMÊÏKòáÂÕÿfl©99ùåp>sQj¤Ø©¸¨©=)7±(I{^PSj\\%Krbv*+#©¶ Ì\0Ma\0a\0¤Ý7\\') + ON DUPLICATE KEY UPDATE `name`='xÚdfíá()ÊÏMÊÏKòáÂÕÿfl©99ùåp>sQj¤Ø©¸¨©=)7±(I{^PSj\\%Krbv*+#©¶ Ì\0Ma\0a\0¤Ý7\\' EOT - ); - $results = await $conn->query('SELECT * FROM table1'); - expect($results->rows())->toBeSame(vec[ - dict[ - 'id' => 123456789, - 'name' => - "xÚdfíá()ÊÏMÊÏKòáÂÕÿfl©99ùåp>sQj¤Ø©¸¨©=)7±(I{^PSj\%Krbv*+#©¶ Ì\0Ma\0a\0¤Ý7\\", - ], - ]); - } - - public async function testDupeInsertEscapingNoMangleBinaryWithAddSlashes(): Awaitable { - $conn = static::$conn as nonnull; - $hex = - '78da9391649366eee12829cacf4dcacf4b95147bf0bf75d2a287d30fa432b055a6e6e4e497230b71169726e5261665a22a64909164031ac300c20c0c005c301ea0'; - $bin = \hex2bin($hex); - $bin_for_query = \addslashes($bin); - await $conn->query(<<query('SELECT * FROM table1'); + expect($results->rows())->toBeSame(vec[ + dict[ + 'id' => 123456789, + 'name' => + "xÚdfíá()ÊÏMÊÏKòáÂÕÿfl©99ùåp>sQj¤Ø©¸¨©=)7±(I{^PSj\%Krbv*+#©¶ Ì\0Ma\0a\0¤Ý7\\", + ], + ]); + } + + public async function testDupeInsertEscapingNoMangleBinaryWithAddSlashes(): Awaitable { + $conn = static::$conn as nonnull; + $hex = + '78da9391649366eee12829cacf4dcacf4b95147bf0bf75d2a287d30fa432b055a6e6e4e497230b71169726e5261665a22a64909164031ac300c20c0c005c301ea0'; + $bin = \hex2bin($hex); + $bin_for_query = \addslashes($bin); + await $conn->query(<<query('SELECT * FROM table1'); - expect($results->rows())->toBeSame(vec[ - dict[ - 'id' => 123456789, - 'name' => "$bin", - ], - ]); - } + ); + $results = await $conn->query('SELECT * FROM table1'); + expect($results->rows())->toBeSame(vec[ + dict[ + 'id' => 123456789, + 'name' => "$bin", + ], + ]); + } } diff --git a/tests/JSONFunctionTest.hack b/tests/JSONFunctionTest.hack index c418ed1..ce09f8a 100644 --- a/tests/JSONFunctionTest.hack +++ b/tests/JSONFunctionTest.hack @@ -15,436 +15,433 @@ type JSONFunctionCompositionExpectedType = shape(?'exception' => classname classname, ?'value' => ?int); final class JSONFunctionTest 
extends HackTest { - private static ?AsyncMysqlConnection $conn; - - <<__Override>> - public static async function beforeFirstTestAsync(): Awaitable { - static::$conn = await SharedSetup::initAsync(); - // block hole logging - Logger::setHandle(new \HH\Lib\IO\MemoryHandle()); - } - - <<__Override>> - public async function beforeEachTestAsync(): Awaitable { - restore('setup'); - QueryContext::$strictSchemaMode = false; - QueryContext::$strictSQLMode = false; - } - - public static async function testJSONValidProvider(): Awaitable> { - return vec[ - // Invalid input - tuple('JSON_VALID()', shape('exception' => SQLFakeRuntimeException::class)), - tuple('JSON_VALID(1, 2)', shape('exception' => SQLFakeRuntimeException::class)), - - // NULL input - tuple('JSON_VALID(NULL)', shape('value' => null)), - - // Valid - tuple("JSON_VALID('null')", shape('value' => 1)), - tuple("JSON_VALID('{\"key\": \"value\"}')", shape('value' => 1)), - tuple("JSON_VALID('[{\"key\": \"value\"}]')", shape('value' => 1)), - tuple("JSON_VALID('\"\"')", shape('value' => 1)), - tuple("JSON_VALID('true')", shape('value' => 1)), - tuple("JSON_VALID('false')", shape('value' => 1)), - tuple("JSON_VALID('2')", shape('value' => 1)), - - // Invalid - tuple("JSON_VALID(' ')", shape('value' => 0)), - tuple("JSON_VALID('arbitrary_string')", shape('value' => 0)), - tuple('JSON_VALID(2)', shape('value' => 0)), - tuple('JSON_VALID(TRUE)', shape('value' => 0)), - ]; - } - - <> - public async function testJSONValid(string $select, JSONValidExpectedType $expected): Awaitable { - await $this->simpleSelectTestCase($select, $expected); - } - - public static async function testJSONQuoteProvider(): Awaitable> { - return vec[ - tuple('JSON_QUOTE(NULL)', shape('value' => null)), - tuple("JSON_QUOTE('null')", shape('value' => '"null"')), - tuple("JSON_QUOTE('a')", shape('value' => '"a"')), - tuple("JSON_QUOTE('{\"a\":\"c\"}')", shape('value' => '"{\"a\":\"c\"}"')), - tuple("JSON_QUOTE('[1, \"\"\"]')", shape('value' => 
'"[1, \"\"\"]"')), - tuple( - "JSON_QUOTE('[1, \"\\\\\"]')", - shape('value' => '"[1, \"\\\\\"]"'), - ), // In PHP \\ represents \ in single & double quoted strings - tuple("JSON_QUOTE('►')", shape('value' => '"►"')), // MySQL doesn't seem to escape these - tuple("JSON_QUOTE('2\n2')", shape('value' => '"2\\n2"')), // Escapes newline - tuple("JSON_QUOTE('".'22'.\chr(8)."')", shape('value' => '"22\\b"')), // Escapes backspace character - - // invalid - tuple('JSON_QUOTE(TRUE)', shape('exception' => SQLFakeRuntimeException::class)), - tuple('JSON_QUOTE(45)', shape('exception' => SQLFakeRuntimeException::class)), - ]; - } - - <> - public async function testJSONQuote(string $select, JSONQuoteExpectedType $expected): Awaitable { - await $this->simpleSelectTestCase($select, $expected); - } - - public static async function testJSONUnquoteProvider(): Awaitable> { - return vec[ - tuple('JSON_UNQUOTE(NULL)', shape('value' => null)), - tuple('JSON_UNQUOTE("null")', shape('value' => 'null')), - tuple('JSON_UNQUOTE("a")', shape('value' => 'a')), - tuple("JSON_UNQUOTE('\\\\')", shape('value' => '\\')), // no-op - tuple('JSON_UNQUOTE(\'"{\\\\\"a\\\\\":\\\\\"c\\\\\"}"\')', shape('value' => '{"a":"c"}')), - tuple('JSON_UNQUOTE(\'[1, """]\')', shape('value' => '[1, """]')), // no-op - tuple('JSON_UNQUOTE(\'"[1, \\\"\\\"\\\"]"\')', shape('value' => '[1, """]')), - tuple('JSON_UNQUOTE(\'"\\\\\\\\"\')', shape('value' => '\\')), - tuple('JSON_UNQUOTE(\'"\\\u25ba"\')', shape('value' => '►')), - tuple('JSON_UNQUOTE(\'"\\\\n"\')', shape('value' => "\n")), - tuple('JSON_UNQUOTE(\'"2\\\\b"\')', shape('value' => '2'.\chr(8))), - - // invalid - tuple( - 'JSON_UNQUOTE(\'"\\\\"\')', - shape('exception' => SQLFakeRuntimeException::class), - ), // inner function receive '"\"' - tuple('JSON_UNQUOTE(2)', shape('exception' => SQLFakeRuntimeException::class)), - ]; - } - - <> - public async function testJSONUnquote(string $select, JSONUnquoteExpectedType $expected): Awaitable { - await 
$this->simpleSelectTestCase($select, $expected); - } - - public static async function testJSONExtractProvider(): Awaitable> { - return vec[ - // invalid input - tuple('JSON_EXTRACT()', shape('exception' => SQLFakeRuntimeException::class)), - tuple("JSON_EXTRACT('[]')", shape('exception' => SQLFakeRuntimeException::class)), - tuple("JSON_EXTRACT('[}', '$.a')", shape('exception' => SQLFakeRuntimeException::class)), - tuple("JSON_EXTRACT('[]', 'a.b')", shape('exception' => SQLFakeRuntimeException::class)), - tuple("JSON_EXTRACT('[]', 2)", shape('exception' => SQLFakeRuntimeException::class)), - tuple("JSON_EXTRACT('[]', TRUE)", shape('exception' => SQLFakeRuntimeException::class)), - - // NULL input - tuple("JSON_EXTRACT(NULL, '$.a')", shape('value' => null)), - tuple("JSON_EXTRACT('{\"a\": {\"b\": \"test\"}}', NULL)", shape('value' => null)), - - // valid - tuple("JSON_EXTRACT('{\"a\": {\"b\": \"test\"}}', '$.a.b')", shape('value' => '"test"')), - tuple("JSON_EXTRACT('{\"a\": {\"b\": 2}}', '$.a.b')", shape('value' => '2')), - tuple("JSON_EXTRACT('{\"a\": {\"b\": true}}', '$.a.b')", shape('value' => 'true')), - tuple( - "JSON_EXTRACT('{\"a\": {\"b\": 2, \"c\": \"test\"}}', '$.a.b', '$.\"a\".\"c\"')", - shape('value' => '[2,"test"]'), - ), - tuple("JSON_EXTRACT('[{\"a\": [{\"b\": \"test\"}]}]', '$[0].a')", shape('value' => '[{"b":"test"}]')), - tuple("JSON_EXTRACT('[{\"a\": [{\"b\": \"test\"}]}]', '$[0].a[0]')", shape('value' => '{"b":"test"}')), - tuple("JSON_EXTRACT('[{\"a\": [{\"b\": \"test\"}]}]', '$**.\"b\"')", shape('value' => '["test"]')), - tuple( - "JSON_EXTRACT('[{\"a\": [{\"b\": \"test\", \"c\":false}]}]', '$**.*')", - shape('value' => '[[{"b":"test","c":false}],"test",false]'), - ), - tuple("JSON_EXTRACT('{\"a\": {\"b\": [\"test\",10]}}', '$**[1]')", shape('value' => '[10]')), - tuple("JSON_EXTRACT('{\"a\":2}', '$.a', '$.b')", shape('value' => '[2]')), - tuple("JSON_EXTRACT('\"a\"', '$')", shape('value' => '"a"')), - - // non-existent - 
tuple("JSON_EXTRACT('2', '$**[1]')", shape('value' => null)), - tuple("JSON_EXTRACT('{\"a\": 2}', '$.b')", shape('value' => null)), - tuple("JSON_EXTRACT('{\"a\": 2}', '$.b', '$[0]')", shape('value' => null)), - ]; - } - - <> - public async function testJSONExtract(string $select, JSONExtractExpectedType $expected): Awaitable { - await $this->simpleSelectTestCase($select, $expected); - } - - public static async function testJSONReplaceProvider(): Awaitable> { - return vec[ - // NULL inputs - tuple("JSON_REPLACE(NULL, '$.a', 2)", shape('value' => null)), - tuple("JSON_REPLACE('{}', NULL, 2)", shape('value' => null)), - - // bad input - tuple("JSON_REPLACE('{}')", shape('exception' => SQLFakeRuntimeException::class)), - tuple("JSON_REPLACE('', '$', 2)", shape('exception' => SQLFakeRuntimeException::class)), - tuple("JSON_REPLACE('{}', '$', 2, '$')", shape('exception' => SQLFakeRuntimeException::class)), - tuple("JSON_REPLACE('{}', '\$fd', 2)", shape('exception' => SQLFakeRuntimeException::class)), - - // non-existent path (no-op) - tuple("JSON_REPLACE('null', '$.a', 2)", shape('value' => 'null')), - tuple("JSON_REPLACE('2', '$.a', 45)", shape('value' => '2')), - tuple("JSON_REPLACE('{\"a\": {\"b\":\"test\"}}', '$.b', 45)", shape('value' => '{"a":{"b":"test"}}')), - - // existent path - tuple("JSON_REPLACE('{}', '$', TRUE)", shape('value' => 'true')), - tuple("JSON_REPLACE('[1]', '$[0]', NULL)", shape('value' => '[null]')), - tuple("JSON_REPLACE('{\"a\":{\"b\":\"test\"}}', '$.a.b', 2)", shape('value' => '{"a":{"b":2}}')), - - // multiple - tuple("JSON_REPLACE('{\"a\": [1,2]}', '$.a[0]', 3, '$.a[1]', 4)", shape('value' => '{"a":[3,4]}')), - - // successive replacements work off interim object - tuple("JSON_REPLACE('{\"a\": [1,2]}', '$.a[0]', 3, '$.a[0]', 4)", shape('value' => '{"a":[4,2]}')), - ]; - } - - <> - public async function testJSONReplace(string $select, JSONReplaceExpectedType $expected): Awaitable { - await $this->simpleSelectTestCase($select, $expected); - 
} - - public static async function testJSONKeysProvider(): Awaitable> { - return vec[ - // invalid input - tuple("JSON_KEYS('{]')", shape('exception' => SQLFakeRuntimeException::class)), - tuple('JSON_KEYS(2)', shape('exception' => SQLFakeRuntimeException::class)), - tuple("JSON_KEYS('{}', 2)", shape('exception' => SQLFakeRuntimeException::class)), - - // null inputs - tuple('JSON_KEYS(NULL)', shape('value' => null)), - tuple("JSON_KEYS('{}', NULL)", shape('value' => null)), - - // points to non-object - tuple("JSON_KEYS('[]')", shape('value' => null)), - tuple("JSON_KEYS('2')", shape('value' => null)), - tuple("JSON_KEYS('true')", shape('value' => null)), - tuple("JSON_KEYS('null')", shape('value' => null)), - tuple("JSON_KEYS('{\"a\":2}', '$.a')", shape('value' => null)), - tuple("JSON_KEYS('[2]', '$[0]')", shape('value' => null)), - - // path is divergent - tuple("JSON_KEYS('{}', '$[*]')", shape('exception' => SQLFakeRuntimeException::class)), - tuple("JSON_KEYS('{}', '$**.a')", shape('exception' => SQLFakeRuntimeException::class)), - - // points to object - tuple("JSON_KEYS('{}')", shape('value' => '[]')), - tuple("JSON_KEYS('{\"a\": {\"b\": 2, \"c\": true}}')", shape('value' => '["a"]')), - tuple("JSON_KEYS('{\"a\": {\"b\": 2, \"c\": true}}', '$.a')", shape('value' => '["b","c"]')), - tuple("JSON_KEYS('[{\"a\": {\"b\": 2, \"c\": true}}]', '$[0]')", shape('value' => '["a"]')), - tuple("JSON_KEYS('[{\"a\": {\"b\": 2, \"c\": true}}]', '$[0].a')", shape('value' => '["b","c"]')), - ]; - } - - <> - public async function testJSONKeys(string $select, JSONKeysExpectedType $expected): Awaitable { - await $this->simpleSelectTestCase($select, $expected); - } - - public static async function testJSONLengthProvider(): Awaitable> { - return vec[ - // invalid input - tuple("JSON_LENGTH('{]')", shape('exception' => SQLFakeRuntimeException::class)), - tuple('JSON_LENGTH(2)', shape('exception' => SQLFakeRuntimeException::class)), - tuple("JSON_LENGTH('{}', 2)", shape('exception' 
=> SQLFakeRuntimeException::class)), - - // null inputs - tuple('JSON_LENGTH(NULL)', shape('value' => null)), - tuple("JSON_LENGTH('{}', NULL)", shape('value' => null)), - - // path is divergent - tuple("JSON_LENGTH('{}', '$[*]')", shape('exception' => SQLFakeRuntimeException::class)), - tuple("JSON_LENGTH('{}', '$**.a')", shape('exception' => SQLFakeRuntimeException::class)), - - // points to object - tuple("JSON_LENGTH('{}')", shape('value' => 0)), - tuple("JSON_LENGTH('{\"a\": {\"b\": 2, \"c\": true}}')", shape('value' => 1)), - tuple("JSON_LENGTH('{\"a\": {\"b\": 2, \"c\": true}}', '$.a')", shape('value' => 2)), - tuple("JSON_LENGTH('[{\"a\": {\"b\": 2, \"c\": true}}]', '$[0]')", shape('value' => 1)), - tuple("JSON_LENGTH('[{\"a\": {\"b\": 2, \"c\": true}}]', '$[0].a')", shape('value' => 2)), - - // points to vec - tuple("JSON_LENGTH('[]')", shape('value' => 0)), - tuple("JSON_LENGTH('[{\"a\": 2}]')", shape('value' => 1)), - tuple("JSON_LENGTH('{\"a\": [true, 1]}', '$.a')", shape('value' => 2)), - tuple("JSON_LENGTH('[[1], 3]', '$[0]')", shape('value' => 1)), - tuple("JSON_LENGTH('[{\"a\": [{\"b\": true}, false, 4]}]', '$[0].a')", shape('value' => 3)), - - // points to scalar - tuple("JSON_LENGTH('\"string\"')", shape('value' => 1)), - tuple("JSON_LENGTH('1')", shape('value' => 1)), - tuple("JSON_LENGTH('true')", shape('value' => 1)), - tuple("JSON_LENGTH('null')", shape('value' => 1)), - tuple("JSON_LENGTH('{\"a\": \"string\", \"b\": 2}', '$.a')", shape('value' => 1)), - tuple("JSON_LENGTH('{\"a\": \"string\", \"b\": 2}', '$.b')", shape('value' => 1)), - tuple("JSON_LENGTH('{\"a\": \"string\", \"b\": true}', '$.b')", shape('value' => 1)), - tuple("JSON_LENGTH('{\"a\": \"string\", \"b\": null}', '$.b')", shape('value' => 1)), - ]; - } - - <> - public async function testJSONLength(string $select, JSONLengthExpectedType $expected): Awaitable { - await $this->simpleSelectTestCase($select, $expected); - } - - public static async function testJSONDepthProvider(): 
Awaitable> { - return vec[ - // invalid input - tuple('JSON_DEPTH()', shape('exception' => SQLFakeRuntimeException::class)), - tuple("JSON_DEPTH('{]')", shape('exception' => SQLFakeRuntimeException::class)), - tuple('JSON_DEPTH(2)', shape('exception' => SQLFakeRuntimeException::class)), - tuple('JSON_DEPTH(TRUE)', shape('exception' => SQLFakeRuntimeException::class)), - - // null input - tuple('JSON_DEPTH(NULL)', shape('value' => null)), - - // depth 1 - tuple("JSON_DEPTH('2')", shape('value' => 1)), - tuple("JSON_DEPTH('true')", shape('value' => 1)), - tuple("JSON_DEPTH('false')", shape('value' => 1)), - tuple("JSON_DEPTH('null')", shape('value' => 1)), - tuple("JSON_DEPTH('\"string\"')", shape('value' => 1)), - tuple("JSON_DEPTH('{}')", shape('value' => 1)), - tuple("JSON_DEPTH('[]')", shape('value' => 1)), - - // depth 2 - tuple( - "JSON_DEPTH('{\"a\": 2, \"b\": [], \"c\": \"string\", \"d\": null, \"e\": true, \"f\": {}}')", - shape('value' => 2), - ), - tuple("JSON_DEPTH('[2, {}, [], \"string\", null, false]')", shape('value' => 2)), - - // depth > 2 - tuple("JSON_DEPTH('[[2]]')", shape('value' => 3)), - tuple("JSON_DEPTH('{\"a\": {\"b\": []}}')", shape('value' => 3)), - tuple("JSON_DEPTH('{\"a\": {\"b\": [true]}}')", shape('value' => 4)), - tuple("JSON_DEPTH('[2, 1, [{\"a\": [3, 4]}]]')", shape('value' => 5)), - tuple("JSON_DEPTH('[2, 1, [{\"a\": [3, [{\"a\": [2]}]]}]]')", shape('value' => 8)), - ]; - } - - <> - public async function testJSONDepth(string $select, JSONDepthExpectedType $expected): Awaitable { - await $this->simpleSelectTestCase($select, $expected); - } - - public static async function testJSONFunctionCompositionProvider( - ): Awaitable> { - return vec[ - // JSON as doc for JSON_EXTRACT - tuple("JSON_EXTRACT(JSON_EXTRACT('{\"a\":true}', '$'), '$.a')", shape('value' => 'true')), - tuple("JSON_EXTRACT(JSON_EXTRACT('[true]', '$[0]'), '$')", shape('value' => 'true')), - - // JSON as doc for JSON_REPLACE - tuple( - 
"JSON_REPLACE(JSON_EXTRACT('{\"a\":{\"b\": 2}}', '$.a'), '$.b', true)", - shape('value' => '{"b":true}'), - ), - tuple("JSON_REPLACE(JSON_EXTRACT(\"[false]\", '$[0]'), '$.a', 'test')", shape('value' => 'false')), - - tuple( - "JSON_REPLACE('[0,1]', '$[1]', REPLACE(JSON_UNQUOTE(JSON_EXTRACT('{\"a\":{\"b\": \"test\"}}', '$.a.b')), 'te', 're'))", - shape('value' => '[0,"rest"]'), - ), - tuple("JSON_REPLACE('{\"b\":2}', '$.b', 1 < 2)", shape('value' => '{"b":true}')), - - // JSON value as replacement in JSON_REPLACE - tuple( - "JSON_REPLACE('[0,1]', '$[1]', JSON_EXTRACT('{\"a\": \"test\"}', '$.a'))", - shape('value' => '[0,"test"]'), - ), - tuple( - "JSON_REPLACE('{\"a\":2}', JSON_UNQUOTE(JSON_EXTRACT('{\"b\":\"$.a\"}', '$.b')), 4)", - shape('value' => '{"a":4}'), - ), - tuple("JSON_UNQUOTE(JSON_EXTRACT('{\"a\": true}', '$.a'))", shape('value' => 'true')), - tuple("JSON_UNQUOTE(JSON_EXTRACT('{\"a\": \"test\"}', '$.a'))", shape('value' => 'test')), - tuple("JSON_UNQUOTE(JSON_EXTRACT('{\"a\": 5}', '$.a'))", shape('value' => '5')), - - // exceptional - // JSON_EXTRACT output as JSON_QUOTE arg - tuple( - "JSON_QUOTE(JSON_EXTRACT('{\"a\":\"test\"}', '$.a'))", - shape('exception' => SQLFakeRuntimeException::class), - ), - // JSON as path for JSON_EXTRACT - tuple( - "JSON_EXTRACT('[true]', JSON_EXTRACT('[\"$[0]\"]', '$[0]'))", - shape('exception' => SQLFakeRuntimeException::class), - ), - // JSON as path for JSON_REPLACE - tuple( - "JSON_REPLACE('{\"a\":2}', JSON_EXTRACT('{\"b\":\"$.a\"}', '$.b'), 2)", - shape('exception' => SQLFakeRuntimeException::class), - ), - ]; - } - - <> - public async function testJSONFunctionComposition( - string $select, - JSONFunctionCompositionExpectedType $expected, - ): Awaitable { - await $this->simpleSelectTestCase($select, $expected); - } - - public static async function testJSONContainsProvider(): Awaitable> { - return vec[ - // NULL inputs - tuple("JSON_CONTAINS(NULL, '{}', '$')", shape('value' => null)), - tuple("JSON_CONTAINS('{}', NULL, 
'$')", shape('value' => null)), - - // bad input - tuple("JSON_CONTAINS('{}')", shape('exception' => SQLFakeRuntimeException::class)), - tuple("JSON_CONTAINS('', 2, '$')", shape('exception' => SQLFakeRuntimeException::class)), - tuple("JSON_CONTAINS('{}', 2, '$', '$')", shape('exception' => SQLFakeRuntimeException::class)), - tuple("JSON_CONTAINS('{}', '\$fd', 2)", shape('exception' => SQLFakeRuntimeException::class)), - - // non-existent path - tuple("JSON_CONTAINS('null', '2', '$.a')", shape('exception' => SQLFakeRuntimeException::class)), - tuple("JSON_CONTAINS('[]', 45, '$.a')", shape('exception' => SQLFakeRuntimeException::class)), - tuple("JSON_CONTAINS('2', 45, '$.a')", shape('exception' => SQLFakeRuntimeException::class)), - tuple("JSON_CONTAINS('{\"a\": {\"b\":\"test\"}}', '45', '$.b')", shape('exception' => SQLFakeRuntimeException::class)), - - // existent path - array - tuple("JSON_CONTAINS('[]', '4', '$')", shape('value' => 0)), - tuple("JSON_CONTAINS('[1, 2, 3, 4]', '1', '$')", shape('value' => 1)), - tuple("JSON_CONTAINS('[\"blue\", \"green\", \"red\", \"yellow\"]', '\"red\"', '$')", shape('value' => 1)), - - // existent path - object - tuple("JSON_CONTAINS('{}', '4', '$')", shape('value' => 0)), - tuple("JSON_CONTAINS('{\"a\": {\"b\":\"test\"}}', '\"test\"', '$.a.b')", shape('value' => 1)), - tuple("JSON_CONTAINS('{\"a\": {\"b\":\"test\"}}', '{\"b\":\"test\"}', '$.a')", shape('value' => 1)), - tuple("JSON_CONTAINS('{\"a\": {\"b\":\"test\"}}', '\"test\"', '$.a')", shape('value' => 1)), - - // no path - array - tuple("JSON_CONTAINS('[]', '4')", shape('value' => 0)), - tuple("JSON_CONTAINS('[1, 2, 3, 4]', '1')", shape('value' => 1)), - tuple("JSON_CONTAINS('[\"blue\", \"green\", \"red\", \"yellow\"]', '\"red\"')", shape('value' => 1)), - - // no path - object - tuple("JSON_CONTAINS('{}', '4')", shape('value' => 0)), - tuple("JSON_CONTAINS('{\"a\": {\"b\":\"test\"}}', '\"test\"')", shape('value' => 0)), - tuple("JSON_CONTAINS('{\"a\": {\"b\":\"test\"}}', 
'{\"a\": {\"b\":\"test\"}}')", shape('value' => 1)), - tuple("JSON_CONTAINS('{\"a\": {\"b\":\"test\"}}', '{\"b\":\"test\"}')", shape('value' => 1)), - ]; - } - - <> - public async function testJSONContains(string $select, JSONContainsExpectedType $expected): Awaitable { - await $this->simpleSelectTestCase($select, $expected); - } - - private async function simpleSelectTestCase( - string $select, - shape(?'exception' => classname, ?'value' => mixed) $expected, - ): Awaitable { - $exception = $expected['exception'] ?? null; - - $conn = static::$conn as nonnull; - $sql = 'SELECT '.$select.' AS expected'; - - if (!$exception) { - $results = await $conn->query($sql); - $expectedValue = $expected['value'] ?? null; - - expect($results->rows())->toBeSame(vec[dict['expected' => $expectedValue]], $sql); - - return; - } - - expect(async () ==> await $conn->query($sql))->toThrow($exception); - } + private static ?AsyncMysqlConnection $conn; + + <<__Override>> + public static async function beforeFirstTestAsync(): Awaitable { + static::$conn = await SharedSetup::initAsync(); + // block hole logging + Logger::setHandle(new \HH\Lib\IO\MemoryHandle()); + } + + <<__Override>> + public async function beforeEachTestAsync(): Awaitable { + restore('setup'); + QueryContext::$strictSchemaMode = false; + QueryContext::$strictSQLMode = false; + } + + public static async function testJSONValidProvider(): Awaitable> { + return vec[ + // Invalid input + tuple('JSON_VALID()', shape('exception' => SQLFakeRuntimeException::class)), + tuple('JSON_VALID(1, 2)', shape('exception' => SQLFakeRuntimeException::class)), + + // NULL input + tuple('JSON_VALID(NULL)', shape('value' => null)), + + // Valid + tuple("JSON_VALID('null')", shape('value' => 1)), + tuple("JSON_VALID('{\"key\": \"value\"}')", shape('value' => 1)), + tuple("JSON_VALID('[{\"key\": \"value\"}]')", shape('value' => 1)), + tuple("JSON_VALID('\"\"')", shape('value' => 1)), + tuple("JSON_VALID('true')", shape('value' => 1)), + 
tuple("JSON_VALID('false')", shape('value' => 1)), + tuple("JSON_VALID('2')", shape('value' => 1)), + + // Invalid + tuple("JSON_VALID(' ')", shape('value' => 0)), + tuple("JSON_VALID('arbitrary_string')", shape('value' => 0)), + tuple('JSON_VALID(2)', shape('value' => 0)), + tuple('JSON_VALID(TRUE)', shape('value' => 0)), + ]; + } + + <> + public async function testJSONValid(string $select, JSONValidExpectedType $expected): Awaitable { + await $this->simpleSelectTestCase($select, $expected); + } + + public static async function testJSONQuoteProvider(): Awaitable> { + return vec[ + tuple('JSON_QUOTE(NULL)', shape('value' => null)), + tuple("JSON_QUOTE('null')", shape('value' => '"null"')), + tuple("JSON_QUOTE('a')", shape('value' => '"a"')), + tuple("JSON_QUOTE('{\"a\":\"c\"}')", shape('value' => '"{\"a\":\"c\"}"')), + tuple("JSON_QUOTE('[1, \"\"\"]')", shape('value' => '"[1, \"\"\"]"')), + tuple( + "JSON_QUOTE('[1, \"\\\\\"]')", + shape('value' => '"[1, \"\\\\\"]"'), + ), // In PHP \\ represents \ in single & double quoted strings + tuple("JSON_QUOTE('►')", shape('value' => '"►"')), // MySQL doesn't seem to escape these + tuple("JSON_QUOTE('2\n2')", shape('value' => '"2\\n2"')), // Escapes newline + tuple("JSON_QUOTE('".'22'.\chr(8)."')", shape('value' => '"22\\b"')), // Escapes backspace character + + // invalid + tuple('JSON_QUOTE(TRUE)', shape('exception' => SQLFakeRuntimeException::class)), + tuple('JSON_QUOTE(45)', shape('exception' => SQLFakeRuntimeException::class)), + ]; + } + + <> + public async function testJSONQuote(string $select, JSONQuoteExpectedType $expected): Awaitable { + await $this->simpleSelectTestCase($select, $expected); + } + + public static async function testJSONUnquoteProvider(): Awaitable> { + return vec[ + tuple('JSON_UNQUOTE(NULL)', shape('value' => null)), + tuple('JSON_UNQUOTE("null")', shape('value' => 'null')), + tuple('JSON_UNQUOTE("a")', shape('value' => 'a')), + tuple("JSON_UNQUOTE('\\\\')", shape('value' => '\\')), // no-op + 
tuple('JSON_UNQUOTE(\'"{\\\\\"a\\\\\":\\\\\"c\\\\\"}"\')', shape('value' => '{"a":"c"}')), + tuple('JSON_UNQUOTE(\'[1, """]\')', shape('value' => '[1, """]')), // no-op + tuple('JSON_UNQUOTE(\'"[1, \\\"\\\"\\\"]"\')', shape('value' => '[1, """]')), + tuple('JSON_UNQUOTE(\'"\\\\\\\\"\')', shape('value' => '\\')), + tuple('JSON_UNQUOTE(\'"\\\u25ba"\')', shape('value' => '►')), + tuple('JSON_UNQUOTE(\'"\\\\n"\')', shape('value' => "\n")), + tuple('JSON_UNQUOTE(\'"2\\\\b"\')', shape('value' => '2'.\chr(8))), + + // invalid + tuple( + 'JSON_UNQUOTE(\'"\\\\"\')', + shape('exception' => SQLFakeRuntimeException::class), + ), // inner function receive '"\"' + tuple('JSON_UNQUOTE(2)', shape('exception' => SQLFakeRuntimeException::class)), + ]; + } + + <> + public async function testJSONUnquote(string $select, JSONUnquoteExpectedType $expected): Awaitable { + await $this->simpleSelectTestCase($select, $expected); + } + + public static async function testJSONExtractProvider(): Awaitable> { + return vec[ + // invalid input + tuple('JSON_EXTRACT()', shape('exception' => SQLFakeRuntimeException::class)), + tuple("JSON_EXTRACT('[]')", shape('exception' => SQLFakeRuntimeException::class)), + tuple("JSON_EXTRACT('[}', '$.a')", shape('exception' => SQLFakeRuntimeException::class)), + tuple("JSON_EXTRACT('[]', 'a.b')", shape('exception' => SQLFakeRuntimeException::class)), + tuple("JSON_EXTRACT('[]', 2)", shape('exception' => SQLFakeRuntimeException::class)), + tuple("JSON_EXTRACT('[]', TRUE)", shape('exception' => SQLFakeRuntimeException::class)), + + // NULL input + tuple("JSON_EXTRACT(NULL, '$.a')", shape('value' => null)), + tuple("JSON_EXTRACT('{\"a\": {\"b\": \"test\"}}', NULL)", shape('value' => null)), + + // valid + tuple("JSON_EXTRACT('{\"a\": {\"b\": \"test\"}}', '$.a.b')", shape('value' => '"test"')), + tuple("JSON_EXTRACT('{\"a\": {\"b\": 2}}', '$.a.b')", shape('value' => '2')), + tuple("JSON_EXTRACT('{\"a\": {\"b\": true}}', '$.a.b')", shape('value' => 'true')), + tuple( 
+ "JSON_EXTRACT('{\"a\": {\"b\": 2, \"c\": \"test\"}}', '$.a.b', '$.\"a\".\"c\"')", + shape('value' => '[2,"test"]'), + ), + tuple("JSON_EXTRACT('[{\"a\": [{\"b\": \"test\"}]}]', '$[0].a')", shape('value' => '[{"b":"test"}]')), + tuple("JSON_EXTRACT('[{\"a\": [{\"b\": \"test\"}]}]', '$[0].a[0]')", shape('value' => '{"b":"test"}')), + tuple("JSON_EXTRACT('[{\"a\": [{\"b\": \"test\"}]}]', '$**.\"b\"')", shape('value' => '["test"]')), + tuple( + "JSON_EXTRACT('[{\"a\": [{\"b\": \"test\", \"c\":false}]}]', '$**.*')", + shape('value' => '[[{"b":"test","c":false}],"test",false]'), + ), + tuple("JSON_EXTRACT('{\"a\": {\"b\": [\"test\",10]}}', '$**[1]')", shape('value' => '[10]')), + tuple("JSON_EXTRACT('{\"a\":2}', '$.a', '$.b')", shape('value' => '[2]')), + tuple("JSON_EXTRACT('\"a\"', '$')", shape('value' => '"a"')), + + // non-existent + tuple("JSON_EXTRACT('2', '$**[1]')", shape('value' => null)), + tuple("JSON_EXTRACT('{\"a\": 2}', '$.b')", shape('value' => null)), + tuple("JSON_EXTRACT('{\"a\": 2}', '$.b', '$[0]')", shape('value' => null)), + ]; + } + + <> + public async function testJSONExtract(string $select, JSONExtractExpectedType $expected): Awaitable { + await $this->simpleSelectTestCase($select, $expected); + } + + public static async function testJSONReplaceProvider(): Awaitable> { + return vec[ + // NULL inputs + tuple("JSON_REPLACE(NULL, '$.a', 2)", shape('value' => null)), + tuple("JSON_REPLACE('{}', NULL, 2)", shape('value' => null)), + + // bad input + tuple("JSON_REPLACE('{}')", shape('exception' => SQLFakeRuntimeException::class)), + tuple("JSON_REPLACE('', '$', 2)", shape('exception' => SQLFakeRuntimeException::class)), + tuple("JSON_REPLACE('{}', '$', 2, '$')", shape('exception' => SQLFakeRuntimeException::class)), + tuple("JSON_REPLACE('{}', '\$fd', 2)", shape('exception' => SQLFakeRuntimeException::class)), + + // non-existent path (no-op) + tuple("JSON_REPLACE('null', '$.a', 2)", shape('value' => 'null')), + tuple("JSON_REPLACE('2', '$.a', 45)", 
shape('value' => '2')), + tuple("JSON_REPLACE('{\"a\": {\"b\":\"test\"}}', '$.b', 45)", shape('value' => '{"a":{"b":"test"}}')), + + // existent path + tuple("JSON_REPLACE('{}', '$', TRUE)", shape('value' => 'true')), + tuple("JSON_REPLACE('[1]', '$[0]', NULL)", shape('value' => '[null]')), + tuple("JSON_REPLACE('{\"a\":{\"b\":\"test\"}}', '$.a.b', 2)", shape('value' => '{"a":{"b":2}}')), + + // multiple + tuple("JSON_REPLACE('{\"a\": [1,2]}', '$.a[0]', 3, '$.a[1]', 4)", shape('value' => '{"a":[3,4]}')), + + // successive replacements work off interim object + tuple("JSON_REPLACE('{\"a\": [1,2]}', '$.a[0]', 3, '$.a[0]', 4)", shape('value' => '{"a":[4,2]}')), + ]; + } + + <> + public async function testJSONReplace(string $select, JSONReplaceExpectedType $expected): Awaitable { + await $this->simpleSelectTestCase($select, $expected); + } + + public static async function testJSONKeysProvider(): Awaitable> { + return vec[ + // invalid input + tuple("JSON_KEYS('{]')", shape('exception' => SQLFakeRuntimeException::class)), + tuple('JSON_KEYS(2)', shape('exception' => SQLFakeRuntimeException::class)), + tuple("JSON_KEYS('{}', 2)", shape('exception' => SQLFakeRuntimeException::class)), + + // null inputs + tuple('JSON_KEYS(NULL)', shape('value' => null)), + tuple("JSON_KEYS('{}', NULL)", shape('value' => null)), + + // points to non-object + tuple("JSON_KEYS('[]')", shape('value' => null)), + tuple("JSON_KEYS('2')", shape('value' => null)), + tuple("JSON_KEYS('true')", shape('value' => null)), + tuple("JSON_KEYS('null')", shape('value' => null)), + tuple("JSON_KEYS('{\"a\":2}', '$.a')", shape('value' => null)), + tuple("JSON_KEYS('[2]', '$[0]')", shape('value' => null)), + + // path is divergent + tuple("JSON_KEYS('{}', '$[*]')", shape('exception' => SQLFakeRuntimeException::class)), + tuple("JSON_KEYS('{}', '$**.a')", shape('exception' => SQLFakeRuntimeException::class)), + + // points to object + tuple("JSON_KEYS('{}')", shape('value' => '[]')), + 
tuple("JSON_KEYS('{\"a\": {\"b\": 2, \"c\": true}}')", shape('value' => '["a"]')), + tuple("JSON_KEYS('{\"a\": {\"b\": 2, \"c\": true}}', '$.a')", shape('value' => '["b","c"]')), + tuple("JSON_KEYS('[{\"a\": {\"b\": 2, \"c\": true}}]', '$[0]')", shape('value' => '["a"]')), + tuple("JSON_KEYS('[{\"a\": {\"b\": 2, \"c\": true}}]', '$[0].a')", shape('value' => '["b","c"]')), + ]; + } + + <> + public async function testJSONKeys(string $select, JSONKeysExpectedType $expected): Awaitable { + await $this->simpleSelectTestCase($select, $expected); + } + + public static async function testJSONLengthProvider(): Awaitable> { + return vec[ + // invalid input + tuple("JSON_LENGTH('{]')", shape('exception' => SQLFakeRuntimeException::class)), + tuple('JSON_LENGTH(2)', shape('exception' => SQLFakeRuntimeException::class)), + tuple("JSON_LENGTH('{}', 2)", shape('exception' => SQLFakeRuntimeException::class)), + + // null inputs + tuple('JSON_LENGTH(NULL)', shape('value' => null)), + tuple("JSON_LENGTH('{}', NULL)", shape('value' => null)), + + // path is divergent + tuple("JSON_LENGTH('{}', '$[*]')", shape('exception' => SQLFakeRuntimeException::class)), + tuple("JSON_LENGTH('{}', '$**.a')", shape('exception' => SQLFakeRuntimeException::class)), + + // points to object + tuple("JSON_LENGTH('{}')", shape('value' => 0)), + tuple("JSON_LENGTH('{\"a\": {\"b\": 2, \"c\": true}}')", shape('value' => 1)), + tuple("JSON_LENGTH('{\"a\": {\"b\": 2, \"c\": true}}', '$.a')", shape('value' => 2)), + tuple("JSON_LENGTH('[{\"a\": {\"b\": 2, \"c\": true}}]', '$[0]')", shape('value' => 1)), + tuple("JSON_LENGTH('[{\"a\": {\"b\": 2, \"c\": true}}]', '$[0].a')", shape('value' => 2)), + + // points to vec + tuple("JSON_LENGTH('[]')", shape('value' => 0)), + tuple("JSON_LENGTH('[{\"a\": 2}]')", shape('value' => 1)), + tuple("JSON_LENGTH('{\"a\": [true, 1]}', '$.a')", shape('value' => 2)), + tuple("JSON_LENGTH('[[1], 3]', '$[0]')", shape('value' => 1)), + tuple("JSON_LENGTH('[{\"a\": [{\"b\": true}, 
false, 4]}]', '$[0].a')", shape('value' => 3)), + + // points to scalar + tuple("JSON_LENGTH('\"string\"')", shape('value' => 1)), + tuple("JSON_LENGTH('1')", shape('value' => 1)), + tuple("JSON_LENGTH('true')", shape('value' => 1)), + tuple("JSON_LENGTH('null')", shape('value' => 1)), + tuple("JSON_LENGTH('{\"a\": \"string\", \"b\": 2}', '$.a')", shape('value' => 1)), + tuple("JSON_LENGTH('{\"a\": \"string\", \"b\": 2}', '$.b')", shape('value' => 1)), + tuple("JSON_LENGTH('{\"a\": \"string\", \"b\": true}', '$.b')", shape('value' => 1)), + tuple("JSON_LENGTH('{\"a\": \"string\", \"b\": null}', '$.b')", shape('value' => 1)), + ]; + } + + <> + public async function testJSONLength(string $select, JSONLengthExpectedType $expected): Awaitable { + await $this->simpleSelectTestCase($select, $expected); + } + + public static async function testJSONDepthProvider(): Awaitable> { + return vec[ + // invalid input + tuple('JSON_DEPTH()', shape('exception' => SQLFakeRuntimeException::class)), + tuple("JSON_DEPTH('{]')", shape('exception' => SQLFakeRuntimeException::class)), + tuple('JSON_DEPTH(2)', shape('exception' => SQLFakeRuntimeException::class)), + tuple('JSON_DEPTH(TRUE)', shape('exception' => SQLFakeRuntimeException::class)), + + // null input + tuple('JSON_DEPTH(NULL)', shape('value' => null)), + + // depth 1 + tuple("JSON_DEPTH('2')", shape('value' => 1)), + tuple("JSON_DEPTH('true')", shape('value' => 1)), + tuple("JSON_DEPTH('false')", shape('value' => 1)), + tuple("JSON_DEPTH('null')", shape('value' => 1)), + tuple("JSON_DEPTH('\"string\"')", shape('value' => 1)), + tuple("JSON_DEPTH('{}')", shape('value' => 1)), + tuple("JSON_DEPTH('[]')", shape('value' => 1)), + + // depth 2 + tuple( + "JSON_DEPTH('{\"a\": 2, \"b\": [], \"c\": \"string\", \"d\": null, \"e\": true, \"f\": {}}')", + shape('value' => 2), + ), + tuple("JSON_DEPTH('[2, {}, [], \"string\", null, false]')", shape('value' => 2)), + + // depth > 2 + tuple("JSON_DEPTH('[[2]]')", shape('value' => 3)), + 
tuple("JSON_DEPTH('{\"a\": {\"b\": []}}')", shape('value' => 3)), + tuple("JSON_DEPTH('{\"a\": {\"b\": [true]}}')", shape('value' => 4)), + tuple("JSON_DEPTH('[2, 1, [{\"a\": [3, 4]}]]')", shape('value' => 5)), + tuple("JSON_DEPTH('[2, 1, [{\"a\": [3, [{\"a\": [2]}]]}]]')", shape('value' => 8)), + ]; + } + + <> + public async function testJSONDepth(string $select, JSONDepthExpectedType $expected): Awaitable { + await $this->simpleSelectTestCase($select, $expected); + } + + public static async function testJSONFunctionCompositionProvider( + ): Awaitable> { + return vec[ + // JSON as doc for JSON_EXTRACT + tuple("JSON_EXTRACT(JSON_EXTRACT('{\"a\":true}', '$'), '$.a')", shape('value' => 'true')), + tuple("JSON_EXTRACT(JSON_EXTRACT('[true]', '$[0]'), '$')", shape('value' => 'true')), + + // JSON as doc for JSON_REPLACE + tuple("JSON_REPLACE(JSON_EXTRACT('{\"a\":{\"b\": 2}}', '$.a'), '$.b', true)", shape('value' => '{"b":true}')), + tuple("JSON_REPLACE(JSON_EXTRACT(\"[false]\", '$[0]'), '$.a', 'test')", shape('value' => 'false')), + + tuple( + "JSON_REPLACE('[0,1]', '$[1]', REPLACE(JSON_UNQUOTE(JSON_EXTRACT('{\"a\":{\"b\": \"test\"}}', '$.a.b')), 'te', 're'))", + shape('value' => '[0,"rest"]'), + ), + tuple("JSON_REPLACE('{\"b\":2}', '$.b', 1 < 2)", shape('value' => '{"b":true}')), + + // JSON value as replacement in JSON_REPLACE + tuple("JSON_REPLACE('[0,1]', '$[1]', JSON_EXTRACT('{\"a\": \"test\"}', '$.a'))", shape('value' => '[0,"test"]')), + tuple( + "JSON_REPLACE('{\"a\":2}', JSON_UNQUOTE(JSON_EXTRACT('{\"b\":\"$.a\"}', '$.b')), 4)", + shape('value' => '{"a":4}'), + ), + tuple("JSON_UNQUOTE(JSON_EXTRACT('{\"a\": true}', '$.a'))", shape('value' => 'true')), + tuple("JSON_UNQUOTE(JSON_EXTRACT('{\"a\": \"test\"}', '$.a'))", shape('value' => 'test')), + tuple("JSON_UNQUOTE(JSON_EXTRACT('{\"a\": 5}', '$.a'))", shape('value' => '5')), + + // exceptional + // JSON_EXTRACT output as JSON_QUOTE arg + tuple( + "JSON_QUOTE(JSON_EXTRACT('{\"a\":\"test\"}', '$.a'))", + 
shape('exception' => SQLFakeRuntimeException::class), + ), + // JSON as path for JSON_EXTRACT + tuple( + "JSON_EXTRACT('[true]', JSON_EXTRACT('[\"$[0]\"]', '$[0]'))", + shape('exception' => SQLFakeRuntimeException::class), + ), + // JSON as path for JSON_REPLACE + tuple( + "JSON_REPLACE('{\"a\":2}', JSON_EXTRACT('{\"b\":\"$.a\"}', '$.b'), 2)", + shape('exception' => SQLFakeRuntimeException::class), + ), + ]; + } + + <> + public async function testJSONFunctionComposition( + string $select, + JSONFunctionCompositionExpectedType $expected, + ): Awaitable { + await $this->simpleSelectTestCase($select, $expected); + } + + public static async function testJSONContainsProvider(): Awaitable> { + return vec[ + // NULL inputs + tuple("JSON_CONTAINS(NULL, '{}', '$')", shape('value' => null)), + tuple("JSON_CONTAINS('{}', NULL, '$')", shape('value' => null)), + + // bad input + tuple("JSON_CONTAINS('{}')", shape('exception' => SQLFakeRuntimeException::class)), + tuple("JSON_CONTAINS('', 2, '$')", shape('exception' => SQLFakeRuntimeException::class)), + tuple("JSON_CONTAINS('{}', 2, '$', '$')", shape('exception' => SQLFakeRuntimeException::class)), + tuple("JSON_CONTAINS('{}', '\$fd', 2)", shape('exception' => SQLFakeRuntimeException::class)), + + // non-existent path + tuple("JSON_CONTAINS('null', '2', '$.a')", shape('exception' => SQLFakeRuntimeException::class)), + tuple("JSON_CONTAINS('[]', 45, '$.a')", shape('exception' => SQLFakeRuntimeException::class)), + tuple("JSON_CONTAINS('2', 45, '$.a')", shape('exception' => SQLFakeRuntimeException::class)), + tuple( + "JSON_CONTAINS('{\"a\": {\"b\":\"test\"}}', '45', '$.b')", + shape('exception' => SQLFakeRuntimeException::class), + ), + + // existent path - array + tuple("JSON_CONTAINS('[]', '4', '$')", shape('value' => 0)), + tuple("JSON_CONTAINS('[1, 2, 3, 4]', '1', '$')", shape('value' => 1)), + tuple("JSON_CONTAINS('[\"blue\", \"green\", \"red\", \"yellow\"]', '\"red\"', '$')", shape('value' => 1)), + + // existent path - 
object + tuple("JSON_CONTAINS('{}', '4', '$')", shape('value' => 0)), + tuple("JSON_CONTAINS('{\"a\": {\"b\":\"test\"}}', '\"test\"', '$.a.b')", shape('value' => 1)), + tuple("JSON_CONTAINS('{\"a\": {\"b\":\"test\"}}', '{\"b\":\"test\"}', '$.a')", shape('value' => 1)), + tuple("JSON_CONTAINS('{\"a\": {\"b\":\"test\"}}', '\"test\"', '$.a')", shape('value' => 1)), + + // no path - array + tuple("JSON_CONTAINS('[]', '4')", shape('value' => 0)), + tuple("JSON_CONTAINS('[1, 2, 3, 4]', '1')", shape('value' => 1)), + tuple("JSON_CONTAINS('[\"blue\", \"green\", \"red\", \"yellow\"]', '\"red\"')", shape('value' => 1)), + + // no path - object + tuple("JSON_CONTAINS('{}', '4')", shape('value' => 0)), + tuple("JSON_CONTAINS('{\"a\": {\"b\":\"test\"}}', '\"test\"')", shape('value' => 0)), + tuple("JSON_CONTAINS('{\"a\": {\"b\":\"test\"}}', '{\"a\": {\"b\":\"test\"}}')", shape('value' => 1)), + tuple("JSON_CONTAINS('{\"a\": {\"b\":\"test\"}}', '{\"b\":\"test\"}')", shape('value' => 1)), + ]; + } + + <> + public async function testJSONContains(string $select, JSONContainsExpectedType $expected): Awaitable { + await $this->simpleSelectTestCase($select, $expected); + } + + private async function simpleSelectTestCase( + string $select, + shape(?'exception' => classname, ?'value' => mixed) $expected, + ): Awaitable { + $exception = $expected['exception'] ?? null; + + $conn = static::$conn as nonnull; + $sql = 'SELECT '.$select.' AS expected'; + + if (!$exception) { + $results = await $conn->query($sql); + $expectedValue = $expected['value'] ?? 
null; + + expect($results->rows())->toBeSame(vec[dict['expected' => $expectedValue]], $sql); + + return; + } + + expect(async () ==> await $conn->query($sql))->toThrow($exception); + } } diff --git a/tests/JSONObjectTest.hack b/tests/JSONObjectTest.hack index 2f2e391..c3dbbaf 100644 --- a/tests/JSONObjectTest.hack +++ b/tests/JSONObjectTest.hack @@ -28,561 +28,561 @@ use type Facebook\HackTest\{DataProvider, HackTest}; use type Slack\SQLFake\JSONPath\{JSONException, JSONObject}; type TestCase = shape( - 'result' => mixed, - 'path' => string, + 'result' => mixed, + 'path' => string, ); final class JSONObjectTest extends HackTest { - private string $json = ' - { - "store": { - "book": [ - { - "category": "reference", - "author": "Nigel Rees", - "title": "Sayings of the Century", - "price": 8.95, - "available": true - }, - { - "category": "fiction", - "author": "Evelyn Waugh", - "title": "Sword of Honour", - "price": 12.99, - "available": false - }, - { - "category": "fiction", - "author": "Herman Melville", - "title": "Moby Dick", - "isbn": "0-553-21311-3", - "price": 8.99, - "available": true - }, - { - "category": "fiction", - "author": "J. R. R. Tolkien", - "title": "The Lord of the Rings", - "isbn": "0-395-19395-8", - "price": 22.99, - "available": false - } - ], - "bicycle": { - "color": "red", - "price": 19.95, - "available": true, - "model": null, - "sku-number": "BCCLE-0001-RD" - } - }, - "authors": [ - "Nigel Rees", - "Evelyn Waugh", - "Herman Melville", - "J. R. R. 
Tolkien" - ], - "Bike models": [ - 1, - 2, - 3 - ], - "movies": [ - { - "name": "Movie 1", - "director": "Director 1" - } - ], - "$under_$-score3d": 2 - } - '; - - public static async function testGetProvider(): Awaitable> { - return vec[ - tuple('$.store.book[-4, -2, -1]', shape('exceptional' => true)), - tuple('$.store.bicycle.price', shape('value' => vec[19.95])), - tuple('$."store".bicycle."price"', shape('value' => vec[19.95])), - tuple('$."store".bicycle.model', shape('value' => vec[null])), - tuple('$.store.bicycle.sku-number', shape('value' => vec['BCCLE-0001-RD'])), - tuple('$.store.bicycle."sku-number"', shape('value' => vec['BCCLE-0001-RD'])), - tuple('$.$under_$-score3d', shape('value' => vec[2])), - tuple('$."$under_$-score3d"', shape('value' => vec[2])), - tuple('$.-$under_$-score3d', shape('exceptional' => true)), - tuple('$.0$under_$-score3d', shape('exceptional' => true)), - tuple( - '$.store.bicycle', - shape( - 'value' => vec[ - dict[ - 'color' => 'red', - 'price' => 19.95, - 'available' => true, - 'model' => null, - 'sku-number' => 'BCCLE-0001-RD', - ], - ], - ), - ), - tuple('$.store.bicycl', shape('value' => null)), - tuple('$.store.book[*].price', shape('value' => vec[8.95, 12.99, 8.99, 22.99])), - tuple('$.store.book[0][price]', shape('exceptional' => true)), - tuple('$.store.book[0]["price"]', shape('exceptional' => true)), - tuple('$.store.book.0', shape('exceptional' => true)), - tuple('$.store.book[7]', shape('value' => null)), - tuple('$.store.book[1, 2].price', shape('exceptional' => true)), - tuple('$.store.book[*][category, author]', shape('exceptional' => true)), - tuple("$.store.book[*]['category', \"author\"]", shape('exceptional' => true)), - tuple('$.store.book[0:3:2].price', shape('exceptional' => true)), - tuple('$.store.bicycle.price[2]', shape('value' => null)), - tuple('$.store.bicycle.price.*', shape('value' => null)), - tuple( - '$.store.bicycle.*', - shape( - 'value' => vec[ - 'red', - 19.95, - true, - null, - 
'BCCLE-0001-RD', - ], - ), - ), - tuple('$**.price', shape('value' => vec[8.95, 12.99, 8.99, 22.99, 19.95])), - tuple('$**."price"', shape('value' => vec[8.95, 12.99, 8.99, 22.99, 19.95])), - tuple('$**price', shape('exceptional' => true)), - tuple('$***.price', shape('exceptional' => true)), - tuple("$.store.book[?(@.category == 'fiction')].price", shape('exceptional' => true)), - tuple('$**[?(@.available == true)].price', shape('exceptional' => true)), - tuple('$**[?(@.available == false)].price', shape('exceptional' => true)), - tuple('$**[?(@.price < 10)].title', shape('exceptional' => true)), - tuple('$**[?(@.price < 10.0)].title', shape('exceptional' => true)), - tuple('$.store.book[?(@.price > 10)].title', shape('exceptional' => true)), - tuple('$**[?(@.author =~ /.*Tolkien/)].title', shape('exceptional' => true)), - tuple('$**[?(@.length <= 5)].color', shape('exceptional' => true)), - tuple('$**[?(@.length <= 5.0)].color', shape('exceptional' => true)), - tuple('$.store.book[?(@.author == $.authors[3])].title', shape('exceptional' => true)), - tuple('$**[?(@.price >= 19.95)][author, color]', shape('exceptional' => true)), - tuple( - "$**[?(@.category == 'fiction' and @.price < 10 or @.color == \"red\")].price", - shape('exceptional' => true), - ), - tuple("$.store.book[?(not @.category == 'fiction')].price", shape('exceptional' => true)), - tuple("$.store.book[?(@.category != 'fiction')].price", shape('exceptional' => true)), - tuple('$**[?(@.color)].color', shape('exceptional' => true)), - tuple("$.store[?(not @..price or @..color == 'red')].available", shape('exceptional' => true)), - tuple('$.store[?(@.price.length == 3)]', shape('exceptional' => true)), - tuple('$.store[?(@.color.length == 3)].price', shape('exceptional' => true)), - tuple('$.store[?(@.color.length == 5)].price', shape('exceptional' => true)), - tuple('$.store[?(@.*.length == 3)]', shape('exceptional' => true)), - tuple('$.store..*[?(@..model == null)].color', shape('exceptional' => 
true)), - tuple("$['Bike models']", shape('exceptional' => true)), - tuple('$["Bike models"]', shape('exceptional' => true)), - tuple( - '$**[1]', - shape( - 'value' => vec[ - dict[ - 'category' => 'fiction', - 'author' => 'Evelyn Waugh', - 'title' => 'Sword of Honour', - 'price' => 12.99, - 'available' => false, - ], - 'Evelyn Waugh', - 2, - ], - ), - ), - tuple( - '$.movies**.*', - shape( - 'value' => vec[ - 'Movie 1', - 'Director 1', - ], - ), - ), - ]; - } - - <> - public async function testGet( - string $jsonPath, - shape(?'exceptional' => bool, ?'value' => mixed) $output, - ): Awaitable { - $exceptional = $output['exceptional'] ?? false; - - $jsonObject = new JSONObject($this->json); - if (!$exceptional) { - invariant(Shapes::keyExists($output, 'value'), 'expected value must be present in non-exceptional cases'); - - $value = $output['value']; - $results = $jsonObject->get($jsonPath); - - if ($value is nonnull) { - expect($results)->toNotBeNull(); - invariant($results is nonnull, 'just making typechecker happy'); - expect($results->value)->toEqual($output['value'], $jsonPath); - } else { - expect($results)->toBeNull(); - } - - return; - } - - expect(() ==> $jsonObject->get($jsonPath))->toThrow(JSONException::class); - } - - public static async function testGetWithUnwrapProvider(): Awaitable> { - return vec[ - tuple('$.store.bicycle.price', shape('value' => 19.95)), - tuple('$."store".bicycle."price"', shape('value' => 19.95)), - tuple( - '$.store.bicycle', - shape( - 'value' => dict[ - 'color' => 'red', - 'price' => 19.95, - 'available' => true, - 'model' => null, - 'sku-number' => 'BCCLE-0001-RD', - ], - ), - ), - tuple( - '$**.bicycle', - shape( - 'value' => vec[ - dict[ - 'color' => 'red', - 'price' => 19.95, - 'available' => true, - 'model' => null, - 'sku-number' => 'BCCLE-0001-RD', - ], - ], - ), - ), - tuple('$**.sku-number', shape('value' => vec['BCCLE-0001-RD'])), - tuple('$.store.bicycl', shape('value' => null)), - tuple('$.store.book[*].price', 
shape('value' => vec[8.95, 12.99, 8.99, 22.99])), - tuple('$.store.book[7]', shape('value' => null)), - tuple('$.store.bicycle.*', shape('value' => vec['red', 19.95, true, null, 'BCCLE-0001-RD'])), - tuple('$**.price', shape('value' => vec[8.95, 12.99, 8.99, 22.99, 19.95])), - ]; - } - - <> - public async function testGetWithUnwrap( - string $jsonPath, - shape(?'exceptional' => bool, ?'value' => mixed) $output, - ): Awaitable { - $exceptional = $output['exceptional'] ?? false; - - $jsonObject = new JSONObject($this->json); - if (!$exceptional) { - invariant(Shapes::keyExists($output, 'value'), 'expected value must be present in non-exceptional cases'); - - $value = $output['value']; - $results = $jsonObject->get($jsonPath, shape('unwrap' => true)); - - if ($value is nonnull) { - expect($results)->toNotBeNull(); - invariant($results is nonnull, 'just making typechecker happy'); - expect($results->value)->toEqual($output['value'], $jsonPath); - } else { - expect($results)->toBeNull(); - } - - return; - } - - expect(() ==> $jsonObject->get($jsonPath, shape('unwrap' => true)))->toThrow(JSONException::class); - } - - public static async function testReplaceProvider(): Awaitable> { - return vec[ - tuple( - shape( - 'json' => dict[ - 'bicycle' => dict[ - 'price' => 19.95, - 'color' => 'red', - 'sku-number' => 'BCCLE-0001-RD', - ], - ] - |> \json_encode($$), - 'path' => '$.bicycle.price', - 'value' => 2000.01, - ), - shape( - 'value' => dict[ - 'bicycle' => dict[ - 'price' => 2000.01, - 'color' => 'red', - 'sku-number' => 'BCCLE-0001-RD', - ], - ], - ), - ), - tuple( - shape( - 'json' => dict[ - 'bicycle' => vec[ - dict[ - 'price' => 19.95, - 'color' => 'red', - 'sku-number' => 'BCCLE-0001-RD', - ], - dict[ - 'price' => 19.96, - 'color' => 'blue', - 'sku-number' => 'BCCLE-0002-RD', - ], - ], - ] - |> \json_encode($$), - 'path' => '$.bicycle[1].price', - 'value' => 2000.01, - ), - shape( - 'value' => dict[ - 'bicycle' => vec[ - dict[ - 'price' => 19.95, - 'color' => 'red', 
- 'sku-number' => 'BCCLE-0001-RD', - ], - dict[ - 'price' => 2000.01, - 'color' => 'blue', - 'sku-number' => 'BCCLE-0002-RD', - ], - ], - ], - ), - ), - tuple( - shape( - 'json' => dict[ - 'bicycle' => dict[ - 'price' => 19.95, - 'color' => 'red', - 'sku-number' => 'BCCLE-0001-RD', - ], - ], - 'path' => '$.bicycle[*]', - 'value' => 2000, - ), - shape('exceptional' => true), - ), - tuple( - shape( - 'json' => dict[ - 'bicycle' => dict[ - 'price' => 19.95, - 'color' => 'red', - 'sku-number' => 'BCCLE-0001-RD', - ], - ], - 'path' => '$**.color', - 'value' => 'blue', - ), - shape('exceptional' => true), - ), - tuple( - shape( - 'json' => dict[ - 'bicycle' => vec[ - dict[ - 'price' => 19.95, - 'color' => 'red', - 'sku-number' => 'BCCLE-0001-RD', - ], - dict[ - 'price' => 19.96, - 'color' => 'blue', - 'sku-number' => 'BCCLE-0002-RD', - ], - ], - ] - |> \json_encode($$), - 'path' => '$.bicycle[0]', - 'value' => true, - ), - shape( - 'value' => dict[ - 'bicycle' => vec[ - true, - dict[ - 'price' => 19.96, - 'color' => 'blue', - 'sku-number' => 'BCCLE-0002-RD', - ], - ], - ], - ), - ), - tuple(shape('json' => '2', 'path' => '$', 'value' => 3), shape('value' => 3)), - ]; - } - - <> - public async function testReplace( - shape('json' => mixed, 'path' => string, 'value' => mixed) $input, - shape(?'exceptional' => bool, ?'value' => mixed) $output, - ): Awaitable { - $jsonPath = $input['path']; - $value = $input['value']; - - $exceptional = $output['exceptional'] ?? 
false; - - $jsonObject = new JSONObject($input['json']); - if (!$exceptional) { - invariant(Shapes::keyExists($output, 'value'), 'expected value must be present in non-exceptional cases'); - $result = $jsonObject->replace($jsonPath, $value); - expect($result->value->getValue())->toEqual($output['value'], $jsonPath); - return; - } - - expect(() ==> $jsonObject->replace($jsonPath, $value))->toThrow(JSONException::class); - } - - public static async function testKeysProvider(): Awaitable> { - return vec[ - tuple(shape('json' => dict[]), shape('value' => vec[])), - tuple(shape('json' => dict['a' => 2, 'b' => 3]), shape('value' => vec['a', 'b'])), - tuple( - shape('json' => dict['upper' => dict['a' => 2, 'b' => 3]], 'path' => '$.upper'), - shape('value' => vec['a', 'b']), - ), - tuple(shape('json' => vec[dict['a' => 2, 'b' => 3]], 'path' => '$[0]'), shape('value' => vec['a', 'b'])), - tuple(shape('json' => dict['c' => dict['a' => 2, 'b' => 3]], 'path' => '$.c'), shape('value' => vec['a', 'b'])), - - // pointing to non-object - tuple(shape('json' => vec[dict['a' => 2, 'b' => 3]]), shape('value' => null)), - tuple(shape('json' => vec[dict['a' => 2, 'b' => 3]], 'path' => '$[0].a'), shape('value' => null)), - tuple(shape('json' => vec[dict['a' => '2', 'b' => 3]], 'path' => '$[0].a'), shape('value' => null)), - tuple(shape('json' => vec[dict['a' => null, 'b' => 3]], 'path' => '$[0].a'), shape('value' => null)), - - // divergent - tuple( - shape('json' => vec[dict['a' => null, 'b' => 3]], 'path' => '$[*]'), - shape('exception' => DivergentJSONPathSetException::class), - ), - - // invalid path - tuple( - shape('json' => vec[dict['a' => null, 'b' => 3]], 'path' => '$[sdfsf]'), - shape('exception' => InvalidJSONPathException::class), - ), - ]; - } - - <> - public async function testKeys( - shape('json' => mixed, ?'path' => string) $input, - shape(?'exception' => classname, ?'value' => vec) $output, - ): Awaitable { - $jsonPath = $input['path'] ?? 
null; - $exception = $output['exception'] ?? null; - - $jsonObject = new JSONObject($input['json']); - if (!$exception) { - invariant(Shapes::keyExists($output, 'value'), 'expected value must be present in non-exceptional cases'); - $result = $jsonPath ? $jsonObject->keys($jsonPath) : $jsonObject->keys(); - - $expected = $output['value']; - if ($expected is nonnull) { - expect($result)->toNotBeNull(); - invariant($result is nonnull, 'expect statement above verified that this is not null'); - expect($result->value)->toEqual($output['value'], $jsonPath ?? 'no JSON path'); - } else { - expect($result)->toBeNull($jsonPath ?? 'no JSON path'); - } - - return; - } - - expect(() ==> $jsonPath ? $jsonObject->keys($jsonPath) : $jsonObject->keys())->toThrow($exception); - } - - public static async function testLengthProvider(): Awaitable> { - return vec[ - // pointing to object - tuple(shape('json' => dict[]), shape('value' => 0)), - tuple(shape('json' => vec[dict['a' => dict['b' => 2], 'c' => 3]], 'path' => '$[0]'), shape('value' => 2)), - tuple(shape('json' => vec[dict['a' => dict['b' => 2], 'c' => 3]], 'path' => '$[0].a'), shape('value' => 1)), - - // pointing to vector - tuple(shape('json' => vec[2, vec[1]]), shape('value' => 2)), - tuple(shape('json' => vec[vec[2, 3], 3, 3], 'path' => '$[0]'), shape('value' => 2)), - tuple(shape('json' => dict['a' => vec[true, false, true]], 'path' => '$.a'), shape('value' => 3)), - - // pointing to scalar - tuple(shape('json' => '"string"'), shape('value' => 1)), - tuple(shape('json' => 'true'), shape('value' => 1)), - tuple(shape('json' => '1'), shape('value' => 1)), - tuple(shape('json' => 'null'), shape('value' => 1)), - tuple(shape('json' => dict['a' => 'string'], 'path' => '$.a'), shape('value' => 1)), - tuple(shape('json' => dict['a' => false], 'path' => '$.a'), shape('value' => 1)), - tuple(shape('json' => dict['a' => 1], 'path' => '$.a'), shape('value' => 1)), - tuple(shape('json' => dict['a' => null], 'path' => '$.a'), 
shape('value' => 1)), - - // pointing to nothing - tuple(shape('json' => '{}', 'path' => '$.a'), shape('value' => null)), - - // divergent - tuple( - shape('json' => vec[dict['a' => null, 'b' => 3]], 'path' => '$[*]'), - shape('exception' => DivergentJSONPathSetException::class), - ), - - // invalid path - tuple( - shape('json' => vec[dict['a' => null, 'b' => 3]], 'path' => '$[sdf]'), - shape('exception' => InvalidJSONPathException::class), - ), - ]; - } - - <> - public async function testLength( - shape('json' => mixed, ?'path' => string) $input, - shape(?'exception' => classname, ?'value' => ?int) $output, - ): Awaitable { - $jsonPath = $input['path'] ?? null; - $exception = $output['exception'] ?? null; - - $jsonObject = new JSONObject($input['json']); - if (!$exception) { - invariant(Shapes::keyExists($output, 'value'), 'expected value must be present in non-exceptional cases'); - - $result = $jsonPath ? $jsonObject->length($jsonPath) : $jsonObject->length(); - $expected = $output['value']; - if ($expected is nonnull) { - expect($result)->toNotBeNull(); - invariant($result is nonnull, 'expect statement above verified that this is not null'); - expect($result->value)->toEqual($output['value'], $jsonPath ?? 'no JSON path'); - } else { - expect($result)->toBeNull($jsonPath ?? 'no JSON path'); - } - - return; - } - - expect(() ==> $jsonPath ? 
$jsonObject->length($jsonPath) : $jsonObject->length())->toThrow($exception); - } - - public static async function testConstructorErrorsProvider(): Awaitable> { - return vec[ - tuple(5), - tuple('{"invalid": json}'), - ]; - } - - <> - public async function testConstructErrors(mixed $json): Awaitable { - expect(() ==> new JSONObject($json))->toThrow(JSONException::class); - } + private string $json = ' + { + "store": { + "book": [ + { + "category": "reference", + "author": "Nigel Rees", + "title": "Sayings of the Century", + "price": 8.95, + "available": true + }, + { + "category": "fiction", + "author": "Evelyn Waugh", + "title": "Sword of Honour", + "price": 12.99, + "available": false + }, + { + "category": "fiction", + "author": "Herman Melville", + "title": "Moby Dick", + "isbn": "0-553-21311-3", + "price": 8.99, + "available": true + }, + { + "category": "fiction", + "author": "J. R. R. Tolkien", + "title": "The Lord of the Rings", + "isbn": "0-395-19395-8", + "price": 22.99, + "available": false + } + ], + "bicycle": { + "color": "red", + "price": 19.95, + "available": true, + "model": null, + "sku-number": "BCCLE-0001-RD" + } + }, + "authors": [ + "Nigel Rees", + "Evelyn Waugh", + "Herman Melville", + "J. R. R. 
Tolkien" + ], + "Bike models": [ + 1, + 2, + 3 + ], + "movies": [ + { + "name": "Movie 1", + "director": "Director 1" + } + ], + "$under_$-score3d": 2 + } + '; + + public static async function testGetProvider(): Awaitable> { + return vec[ + tuple('$.store.book[-4, -2, -1]', shape('exceptional' => true)), + tuple('$.store.bicycle.price', shape('value' => vec[19.95])), + tuple('$."store".bicycle."price"', shape('value' => vec[19.95])), + tuple('$."store".bicycle.model', shape('value' => vec[null])), + tuple('$.store.bicycle.sku-number', shape('value' => vec['BCCLE-0001-RD'])), + tuple('$.store.bicycle."sku-number"', shape('value' => vec['BCCLE-0001-RD'])), + tuple('$.$under_$-score3d', shape('value' => vec[2])), + tuple('$."$under_$-score3d"', shape('value' => vec[2])), + tuple('$.-$under_$-score3d', shape('exceptional' => true)), + tuple('$.0$under_$-score3d', shape('exceptional' => true)), + tuple( + '$.store.bicycle', + shape( + 'value' => vec[ + dict[ + 'color' => 'red', + 'price' => 19.95, + 'available' => true, + 'model' => null, + 'sku-number' => 'BCCLE-0001-RD', + ], + ], + ), + ), + tuple('$.store.bicycl', shape('value' => null)), + tuple('$.store.book[*].price', shape('value' => vec[8.95, 12.99, 8.99, 22.99])), + tuple('$.store.book[0][price]', shape('exceptional' => true)), + tuple('$.store.book[0]["price"]', shape('exceptional' => true)), + tuple('$.store.book.0', shape('exceptional' => true)), + tuple('$.store.book[7]', shape('value' => null)), + tuple('$.store.book[1, 2].price', shape('exceptional' => true)), + tuple('$.store.book[*][category, author]', shape('exceptional' => true)), + tuple("$.store.book[*]['category', \"author\"]", shape('exceptional' => true)), + tuple('$.store.book[0:3:2].price', shape('exceptional' => true)), + tuple('$.store.bicycle.price[2]', shape('value' => null)), + tuple('$.store.bicycle.price.*', shape('value' => null)), + tuple( + '$.store.bicycle.*', + shape( + 'value' => vec[ + 'red', + 19.95, + true, + null, + 
'BCCLE-0001-RD', + ], + ), + ), + tuple('$**.price', shape('value' => vec[8.95, 12.99, 8.99, 22.99, 19.95])), + tuple('$**."price"', shape('value' => vec[8.95, 12.99, 8.99, 22.99, 19.95])), + tuple('$**price', shape('exceptional' => true)), + tuple('$***.price', shape('exceptional' => true)), + tuple("$.store.book[?(@.category == 'fiction')].price", shape('exceptional' => true)), + tuple('$**[?(@.available == true)].price', shape('exceptional' => true)), + tuple('$**[?(@.available == false)].price', shape('exceptional' => true)), + tuple('$**[?(@.price < 10)].title', shape('exceptional' => true)), + tuple('$**[?(@.price < 10.0)].title', shape('exceptional' => true)), + tuple('$.store.book[?(@.price > 10)].title', shape('exceptional' => true)), + tuple('$**[?(@.author =~ /.*Tolkien/)].title', shape('exceptional' => true)), + tuple('$**[?(@.length <= 5)].color', shape('exceptional' => true)), + tuple('$**[?(@.length <= 5.0)].color', shape('exceptional' => true)), + tuple('$.store.book[?(@.author == $.authors[3])].title', shape('exceptional' => true)), + tuple('$**[?(@.price >= 19.95)][author, color]', shape('exceptional' => true)), + tuple( + "$**[?(@.category == 'fiction' and @.price < 10 or @.color == \"red\")].price", + shape('exceptional' => true), + ), + tuple("$.store.book[?(not @.category == 'fiction')].price", shape('exceptional' => true)), + tuple("$.store.book[?(@.category != 'fiction')].price", shape('exceptional' => true)), + tuple('$**[?(@.color)].color', shape('exceptional' => true)), + tuple("$.store[?(not @..price or @..color == 'red')].available", shape('exceptional' => true)), + tuple('$.store[?(@.price.length == 3)]', shape('exceptional' => true)), + tuple('$.store[?(@.color.length == 3)].price', shape('exceptional' => true)), + tuple('$.store[?(@.color.length == 5)].price', shape('exceptional' => true)), + tuple('$.store[?(@.*.length == 3)]', shape('exceptional' => true)), + tuple('$.store..*[?(@..model == null)].color', shape('exceptional' => 
true)), + tuple("$['Bike models']", shape('exceptional' => true)), + tuple('$["Bike models"]', shape('exceptional' => true)), + tuple( + '$**[1]', + shape( + 'value' => vec[ + dict[ + 'category' => 'fiction', + 'author' => 'Evelyn Waugh', + 'title' => 'Sword of Honour', + 'price' => 12.99, + 'available' => false, + ], + 'Evelyn Waugh', + 2, + ], + ), + ), + tuple( + '$.movies**.*', + shape( + 'value' => vec[ + 'Movie 1', + 'Director 1', + ], + ), + ), + ]; + } + + <> + public async function testGet( + string $jsonPath, + shape(?'exceptional' => bool, ?'value' => mixed) $output, + ): Awaitable { + $exceptional = $output['exceptional'] ?? false; + + $jsonObject = new JSONObject($this->json); + if (!$exceptional) { + invariant(Shapes::keyExists($output, 'value'), 'expected value must be present in non-exceptional cases'); + + $value = $output['value']; + $results = $jsonObject->get($jsonPath); + + if ($value is nonnull) { + expect($results)->toNotBeNull(); + invariant($results is nonnull, 'just making typechecker happy'); + expect($results->value)->toEqual($output['value'], $jsonPath); + } else { + expect($results)->toBeNull(); + } + + return; + } + + expect(() ==> $jsonObject->get($jsonPath))->toThrow(JSONException::class); + } + + public static async function testGetWithUnwrapProvider(): Awaitable> { + return vec[ + tuple('$.store.bicycle.price', shape('value' => 19.95)), + tuple('$."store".bicycle."price"', shape('value' => 19.95)), + tuple( + '$.store.bicycle', + shape( + 'value' => dict[ + 'color' => 'red', + 'price' => 19.95, + 'available' => true, + 'model' => null, + 'sku-number' => 'BCCLE-0001-RD', + ], + ), + ), + tuple( + '$**.bicycle', + shape( + 'value' => vec[ + dict[ + 'color' => 'red', + 'price' => 19.95, + 'available' => true, + 'model' => null, + 'sku-number' => 'BCCLE-0001-RD', + ], + ], + ), + ), + tuple('$**.sku-number', shape('value' => vec['BCCLE-0001-RD'])), + tuple('$.store.bicycl', shape('value' => null)), + tuple('$.store.book[*].price', 
shape('value' => vec[8.95, 12.99, 8.99, 22.99])), + tuple('$.store.book[7]', shape('value' => null)), + tuple('$.store.bicycle.*', shape('value' => vec['red', 19.95, true, null, 'BCCLE-0001-RD'])), + tuple('$**.price', shape('value' => vec[8.95, 12.99, 8.99, 22.99, 19.95])), + ]; + } + + <> + public async function testGetWithUnwrap( + string $jsonPath, + shape(?'exceptional' => bool, ?'value' => mixed) $output, + ): Awaitable { + $exceptional = $output['exceptional'] ?? false; + + $jsonObject = new JSONObject($this->json); + if (!$exceptional) { + invariant(Shapes::keyExists($output, 'value'), 'expected value must be present in non-exceptional cases'); + + $value = $output['value']; + $results = $jsonObject->get($jsonPath, shape('unwrap' => true)); + + if ($value is nonnull) { + expect($results)->toNotBeNull(); + invariant($results is nonnull, 'just making typechecker happy'); + expect($results->value)->toEqual($output['value'], $jsonPath); + } else { + expect($results)->toBeNull(); + } + + return; + } + + expect(() ==> $jsonObject->get($jsonPath, shape('unwrap' => true)))->toThrow(JSONException::class); + } + + public static async function testReplaceProvider(): Awaitable> { + return vec[ + tuple( + shape( + 'json' => dict[ + 'bicycle' => dict[ + 'price' => 19.95, + 'color' => 'red', + 'sku-number' => 'BCCLE-0001-RD', + ], + ] + |> \json_encode($$), + 'path' => '$.bicycle.price', + 'value' => 2000.01, + ), + shape( + 'value' => dict[ + 'bicycle' => dict[ + 'price' => 2000.01, + 'color' => 'red', + 'sku-number' => 'BCCLE-0001-RD', + ], + ], + ), + ), + tuple( + shape( + 'json' => dict[ + 'bicycle' => vec[ + dict[ + 'price' => 19.95, + 'color' => 'red', + 'sku-number' => 'BCCLE-0001-RD', + ], + dict[ + 'price' => 19.96, + 'color' => 'blue', + 'sku-number' => 'BCCLE-0002-RD', + ], + ], + ] + |> \json_encode($$), + 'path' => '$.bicycle[1].price', + 'value' => 2000.01, + ), + shape( + 'value' => dict[ + 'bicycle' => vec[ + dict[ + 'price' => 19.95, + 'color' => 'red', 
+ 'sku-number' => 'BCCLE-0001-RD', + ], + dict[ + 'price' => 2000.01, + 'color' => 'blue', + 'sku-number' => 'BCCLE-0002-RD', + ], + ], + ], + ), + ), + tuple( + shape( + 'json' => dict[ + 'bicycle' => dict[ + 'price' => 19.95, + 'color' => 'red', + 'sku-number' => 'BCCLE-0001-RD', + ], + ], + 'path' => '$.bicycle[*]', + 'value' => 2000, + ), + shape('exceptional' => true), + ), + tuple( + shape( + 'json' => dict[ + 'bicycle' => dict[ + 'price' => 19.95, + 'color' => 'red', + 'sku-number' => 'BCCLE-0001-RD', + ], + ], + 'path' => '$**.color', + 'value' => 'blue', + ), + shape('exceptional' => true), + ), + tuple( + shape( + 'json' => dict[ + 'bicycle' => vec[ + dict[ + 'price' => 19.95, + 'color' => 'red', + 'sku-number' => 'BCCLE-0001-RD', + ], + dict[ + 'price' => 19.96, + 'color' => 'blue', + 'sku-number' => 'BCCLE-0002-RD', + ], + ], + ] + |> \json_encode($$), + 'path' => '$.bicycle[0]', + 'value' => true, + ), + shape( + 'value' => dict[ + 'bicycle' => vec[ + true, + dict[ + 'price' => 19.96, + 'color' => 'blue', + 'sku-number' => 'BCCLE-0002-RD', + ], + ], + ], + ), + ), + tuple(shape('json' => '2', 'path' => '$', 'value' => 3), shape('value' => 3)), + ]; + } + + <> + public async function testReplace( + shape('json' => mixed, 'path' => string, 'value' => mixed) $input, + shape(?'exceptional' => bool, ?'value' => mixed) $output, + ): Awaitable { + $jsonPath = $input['path']; + $value = $input['value']; + + $exceptional = $output['exceptional'] ?? 
false; + + $jsonObject = new JSONObject($input['json']); + if (!$exceptional) { + invariant(Shapes::keyExists($output, 'value'), 'expected value must be present in non-exceptional cases'); + $result = $jsonObject->replace($jsonPath, $value); + expect($result->value->getValue())->toEqual($output['value'], $jsonPath); + return; + } + + expect(() ==> $jsonObject->replace($jsonPath, $value))->toThrow(JSONException::class); + } + + public static async function testKeysProvider(): Awaitable> { + return vec[ + tuple(shape('json' => dict[]), shape('value' => vec[])), + tuple(shape('json' => dict['a' => 2, 'b' => 3]), shape('value' => vec['a', 'b'])), + tuple( + shape('json' => dict['upper' => dict['a' => 2, 'b' => 3]], 'path' => '$.upper'), + shape('value' => vec['a', 'b']), + ), + tuple(shape('json' => vec[dict['a' => 2, 'b' => 3]], 'path' => '$[0]'), shape('value' => vec['a', 'b'])), + tuple(shape('json' => dict['c' => dict['a' => 2, 'b' => 3]], 'path' => '$.c'), shape('value' => vec['a', 'b'])), + + // pointing to non-object + tuple(shape('json' => vec[dict['a' => 2, 'b' => 3]]), shape('value' => null)), + tuple(shape('json' => vec[dict['a' => 2, 'b' => 3]], 'path' => '$[0].a'), shape('value' => null)), + tuple(shape('json' => vec[dict['a' => '2', 'b' => 3]], 'path' => '$[0].a'), shape('value' => null)), + tuple(shape('json' => vec[dict['a' => null, 'b' => 3]], 'path' => '$[0].a'), shape('value' => null)), + + // divergent + tuple( + shape('json' => vec[dict['a' => null, 'b' => 3]], 'path' => '$[*]'), + shape('exception' => DivergentJSONPathSetException::class), + ), + + // invalid path + tuple( + shape('json' => vec[dict['a' => null, 'b' => 3]], 'path' => '$[sdfsf]'), + shape('exception' => InvalidJSONPathException::class), + ), + ]; + } + + <> + public async function testKeys( + shape('json' => mixed, ?'path' => string) $input, + shape(?'exception' => classname, ?'value' => vec) $output, + ): Awaitable { + $jsonPath = $input['path'] ?? 
null; + $exception = $output['exception'] ?? null; + + $jsonObject = new JSONObject($input['json']); + if (!$exception) { + invariant(Shapes::keyExists($output, 'value'), 'expected value must be present in non-exceptional cases'); + $result = $jsonPath ? $jsonObject->keys($jsonPath) : $jsonObject->keys(); + + $expected = $output['value']; + if ($expected is nonnull) { + expect($result)->toNotBeNull(); + invariant($result is nonnull, 'expect statement above verified that this is not null'); + expect($result->value)->toEqual($output['value'], $jsonPath ?? 'no JSON path'); + } else { + expect($result)->toBeNull($jsonPath ?? 'no JSON path'); + } + + return; + } + + expect(() ==> $jsonPath ? $jsonObject->keys($jsonPath) : $jsonObject->keys())->toThrow($exception); + } + + public static async function testLengthProvider(): Awaitable> { + return vec[ + // pointing to object + tuple(shape('json' => dict[]), shape('value' => 0)), + tuple(shape('json' => vec[dict['a' => dict['b' => 2], 'c' => 3]], 'path' => '$[0]'), shape('value' => 2)), + tuple(shape('json' => vec[dict['a' => dict['b' => 2], 'c' => 3]], 'path' => '$[0].a'), shape('value' => 1)), + + // pointing to vector + tuple(shape('json' => vec[2, vec[1]]), shape('value' => 2)), + tuple(shape('json' => vec[vec[2, 3], 3, 3], 'path' => '$[0]'), shape('value' => 2)), + tuple(shape('json' => dict['a' => vec[true, false, true]], 'path' => '$.a'), shape('value' => 3)), + + // pointing to scalar + tuple(shape('json' => '"string"'), shape('value' => 1)), + tuple(shape('json' => 'true'), shape('value' => 1)), + tuple(shape('json' => '1'), shape('value' => 1)), + tuple(shape('json' => 'null'), shape('value' => 1)), + tuple(shape('json' => dict['a' => 'string'], 'path' => '$.a'), shape('value' => 1)), + tuple(shape('json' => dict['a' => false], 'path' => '$.a'), shape('value' => 1)), + tuple(shape('json' => dict['a' => 1], 'path' => '$.a'), shape('value' => 1)), + tuple(shape('json' => dict['a' => null], 'path' => '$.a'), 
shape('value' => 1)), + + // pointing to nothing + tuple(shape('json' => '{}', 'path' => '$.a'), shape('value' => null)), + + // divergent + tuple( + shape('json' => vec[dict['a' => null, 'b' => 3]], 'path' => '$[*]'), + shape('exception' => DivergentJSONPathSetException::class), + ), + + // invalid path + tuple( + shape('json' => vec[dict['a' => null, 'b' => 3]], 'path' => '$[sdf]'), + shape('exception' => InvalidJSONPathException::class), + ), + ]; + } + + <> + public async function testLength( + shape('json' => mixed, ?'path' => string) $input, + shape(?'exception' => classname, ?'value' => ?int) $output, + ): Awaitable { + $jsonPath = $input['path'] ?? null; + $exception = $output['exception'] ?? null; + + $jsonObject = new JSONObject($input['json']); + if (!$exception) { + invariant(Shapes::keyExists($output, 'value'), 'expected value must be present in non-exceptional cases'); + + $result = $jsonPath ? $jsonObject->length($jsonPath) : $jsonObject->length(); + $expected = $output['value']; + if ($expected is nonnull) { + expect($result)->toNotBeNull(); + invariant($result is nonnull, 'expect statement above verified that this is not null'); + expect($result->value)->toEqual($output['value'], $jsonPath ?? 'no JSON path'); + } else { + expect($result)->toBeNull($jsonPath ?? 'no JSON path'); + } + + return; + } + + expect(() ==> $jsonPath ? 
$jsonObject->length($jsonPath) : $jsonObject->length())->toThrow($exception); + } + + public static async function testConstructorErrorsProvider(): Awaitable> { + return vec[ + tuple(5), + tuple('{"invalid": json}'), + ]; + } + + <> + public async function testConstructErrors(mixed $json): Awaitable { + expect(() ==> new JSONObject($json))->toThrow(JSONException::class); + } } diff --git a/tests/JoinQueryTest.php b/tests/JoinQueryTest.php index 4ac73b7..d56d427 100644 --- a/tests/JoinQueryTest.php +++ b/tests/JoinQueryTest.php @@ -49,9 +49,8 @@ final class JoinQueryTest extends HackTest { $results = await $conn->query('SELECT * FROM table3 JOIN association_table ON id = table_3_id'); expect($results->rows())->toBeSame($expected, 'with no aliases and column names inferred from table schema'); - $results = await $conn->query( - 'SELECT * FROM table3 JOIN association_table ON table3.id = association_table.table_3_id', - ); + $results = + await $conn->query('SELECT * FROM table3 JOIN association_table ON table3.id = association_table.table_3_id'); expect($results->rows())->toBeSame($expected, 'with columns using explicitly specified table names'); $results = await $conn->query( @@ -91,9 +90,8 @@ final class JoinQueryTest extends HackTest { public async function testLeftJoin(): Awaitable { $conn = static::$conn as nonnull; - $results = await $conn->query( - 'SELECT id, table_4_id FROM table3 LEFT OUTER JOIN association_table ON id=table_3_id', - ); + $results = + await $conn->query('SELECT id, table_4_id FROM table3 LEFT OUTER JOIN association_table ON id=table_3_id'); expect($results->rows())->toBeSame(vec[ dict['id' => 1, 'table_4_id' => 1000], dict['id' => 1, 'table_4_id' => 1001], @@ -121,9 +119,8 @@ final class JoinQueryTest extends HackTest { public async function testCrossJoin(): Awaitable { $conn = static::$conn as nonnull; - $results = await $conn->query( - 'SELECT table_3_id, id as table_4_id FROM association_table, table4 WHERE table4.id=1003', - ); + 
$results = + await $conn->query('SELECT table_3_id, id as table_4_id FROM association_table, table4 WHERE table4.id=1003'); expect($results->rows())->toBeSame(vec[ dict['table_3_id' => 1, 'table_4_id' => 1003], dict['table_3_id' => 1, 'table_4_id' => 1003], diff --git a/tests/MultiQueryTest.php b/tests/MultiQueryTest.php index 9e674d4..9d26f30 100644 --- a/tests/MultiQueryTest.php +++ b/tests/MultiQueryTest.php @@ -10,9 +10,8 @@ final class MultiQueryTest extends HackTest { public async function testSubqueryInSelect(): Awaitable { $conn = static::$conn as nonnull; - $results = await $conn->query( - 'SELECT * FROM (SELECT id FROM table4 WHERE id = 1001 OR id = 1004) as sub WHERE id = 1001', - ); + $results = + await $conn->query('SELECT * FROM (SELECT id FROM table4 WHERE id = 1001 OR id = 1004) as sub WHERE id = 1001'); expect($results->rows())->toBeSame(vec[ dict['id' => 1001], ]); diff --git a/tests/SQLFunctionTest.php b/tests/SQLFunctionTest.php index fc6a60d..22dd353 100644 --- a/tests/SQLFunctionTest.php +++ b/tests/SQLFunctionTest.php @@ -36,14 +36,12 @@ final class SQLFunctionTest extends HackTest { $results = await $conn->query('select group_id, count(*) from association_table group by 1'); expect($results->rows())->toBeSame($expected, 'with positional reference in group_by'); - $results = await $conn->query( - 'select group_id, count(*) from association_table group by association_table.group_id', - ); + $results = + await $conn->query('select group_id, count(*) from association_table group by association_table.group_id'); expect($results->rows())->toBeSame($expected, 'with column and alias reference in group_by'); - $results = await $conn->query( - 'select group_id, count(1) from association_table group by association_table.group_id', - ); + $results = + await $conn->query('select group_id, count(1) from association_table group by association_table.group_id'); expect($results->rows())->toBeSame( vec[dict['group_id' => 12345, 'count(1)' => 3], 
dict['group_id' => 0, 'count(1)' => 1]], 'with count(1) instead of count(*)', @@ -60,9 +58,8 @@ final class SQLFunctionTest extends HackTest { public async function testCountDistinct(): Awaitable { $conn = static::$conn as nonnull; - $results = await $conn->query( - 'select group_id, count(DISTINCT table_3_id) from association_table group by group_id', - ); + $results = + await $conn->query('select group_id, count(DISTINCT table_3_id) from association_table group by group_id'); expect($results->rows())->toBeSame( vec[ dict['group_id' => 12345, 'count(DISTINCT table_3_id)' => 2], @@ -114,9 +111,8 @@ final class SQLFunctionTest extends HackTest { public async function testMinMax(): Awaitable { $conn = static::$conn as nonnull; - $results = await $conn->query( - 'SELECT group_id, MIN(table_4_id), MAX(table_4_id) FROM association_table GROUP BY group_id', - ); + $results = + await $conn->query('SELECT group_id, MIN(table_4_id), MAX(table_4_id) FROM association_table GROUP BY group_id'); expect($results->rows())->toBeSame( vec[ dict[ @@ -135,9 +131,8 @@ final class SQLFunctionTest extends HackTest { public async function testAggNoGroupBy(): Awaitable { $conn = static::$conn as nonnull; - $results = await $conn->query( - 'SELECT COUNT(*), MIN(table_4_id), MAX(table_4_id), AVG(table_4_id) FROM association_table', - ); + $results = + await $conn->query('SELECT COUNT(*), MIN(table_4_id), MAX(table_4_id), AVG(table_4_id) FROM association_table'); expect($results->rows())->toBeSame( vec[ dict[ @@ -370,9 +365,7 @@ final class SQLFunctionTest extends HackTest { // with alias being the same as the column name $results = await $conn->query('select max(id) as id, group_id from table3 group by group_id order by id'); - expect($results->rows())->toBeSame( - vec[dict['id' => 3, 'group_id' => 12345], dict['id' => 6, 'group_id' => 6]], - ); + expect($results->rows())->toBeSame(vec[dict['id' => 3, 'group_id' => 12345], dict['id' => 6, 'group_id' => 6]]); // with positional identifier 
$results = await $conn->query('select max(id), group_id from table3 group by group_id order by 1'); diff --git a/tests/SelectClauseTest.php b/tests/SelectClauseTest.php index b66d340..06cadb7 100644 --- a/tests/SelectClauseTest.php +++ b/tests/SelectClauseTest.php @@ -63,9 +63,8 @@ final class SelectClauseTest extends HackTest { 'with backtick quoted identifiers', ); - $results = await $conn->query( - 'SELECT table3.id, table3.group_id as my_fav_group_id FROM `db2`.`table3` WHERE group_id=6', - ); + $results = + await $conn->query('SELECT table3.id, table3.group_id as my_fav_group_id FROM `db2`.`table3` WHERE group_id=6'); expect($results->rows())->toBeSame( vec[ dict['id' => 4, 'my_fav_group_id' => 6], @@ -194,9 +193,8 @@ final class SelectClauseTest extends HackTest { public async function testLimit(): Awaitable { $conn = static::$conn as nonnull; - $results = await $conn->query( - 'select group_id, table_4_id from association_table ORDER BY group_id, 2 DESC LIMIT 2', - ); + $results = + await $conn->query('select group_id, table_4_id from association_table ORDER BY group_id, 2 DESC LIMIT 2'); expect($results->rows())->toBeSame( vec[ dict['group_id' => 0, 'table_4_id' => 1003], @@ -235,9 +233,7 @@ final class SelectClauseTest extends HackTest { public async function testOutOfOrderClauses(): Awaitable { $conn = static::$conn as nonnull; - expect( - () ==> $conn->query('select group_id, table_4_id from association_table LIMIT 2 ORDER BY group_id, 2 DESC'), - ) + expect(() ==> $conn->query('select group_id, table_4_id from association_table LIMIT 2 ORDER BY group_id, 2 DESC')) ->toThrow(SQLFakeParseException::class, 'Unexpected ORDER'); } diff --git a/tests/SelectExpressionTest.php b/tests/SelectExpressionTest.php index b9847d4..7a317f9 100644 --- a/tests/SelectExpressionTest.php +++ b/tests/SelectExpressionTest.php @@ -46,9 +46,8 @@ final class SelectExpressionTest extends HackTest { public async function testSelectExpressions(): Awaitable { $conn = static::$conn 
as nonnull; - $results = await $conn->query( - 'SELECT id, group_id as my_fav_group_id, id*1000 as math FROM table3 WHERE group_id=6', - ); + $results = + await $conn->query('SELECT id, group_id as my_fav_group_id, id*1000 as math FROM table3 WHERE group_id=6'); expect($results->rows())->toBeSame( vec[ dict['id' => 4, 'my_fav_group_id' => 6, 'math' => 4000], @@ -200,9 +199,7 @@ final class SelectExpressionTest extends HackTest { // weird this is even valid SQL, and possibly pedantic, but this demonstrates a lot of how // case statements are implemented such that it doesn't blow up on the second THEN or second CASE - $results = await $conn->query( - "SELECT CASE WHEN 4 = CASE WHEN 1 = 2 THEN 3 ELSE 4 END THEN 'yes' ELSE 'no' END", - ); + $results = await $conn->query("SELECT CASE WHEN 4 = CASE WHEN 1 = 2 THEN 3 ELSE 4 END THEN 'yes' ELSE 'no' END"); expect($results->rows())->toBeSame( vec[ dict["CASE WHEN 4 = CASE WHEN 1 = 2 THEN 3 ELSE 4 END THEN 'yes' ELSE 'no' END" => 'yes'], @@ -341,9 +338,7 @@ final class SelectExpressionTest extends HackTest { public async function testNotParens(): Awaitable { $conn = static::$conn as nonnull; - $results = await $conn->query( - "SELECT id FROM table3 WHERE group_id=12345 AND NOT (name='name1' OR name='name3')", - ); + $results = await $conn->query("SELECT id FROM table3 WHERE group_id=12345 AND NOT (name='name1' OR name='name3')"); expect($results->rows())->toBeSame(vec[ dict['id' => 2], ]); diff --git a/tests/SelectQueryValidatorTest.php b/tests/SelectQueryValidatorTest.php index 106a683..f82412c 100644 --- a/tests/SelectQueryValidatorTest.php +++ b/tests/SelectQueryValidatorTest.php @@ -9,69 +9,69 @@ // https://github.com/vitessio/vitess/blob/master/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt final class SelectQueryValidatorTest extends HackTest { - private static ?AsyncMysqlConnection $conn; + private static ?AsyncMysqlConnection $conn; - <<__Override>> - public static async function beforeFirstTestAsync(): 
Awaitable { - static::$conn = await SharedSetup::initVitessAsync(); - // block hole logging - Logger::setHandle(new \HH\Lib\IO\MemoryHandle()); - } + <<__Override>> + public static async function beforeFirstTestAsync(): Awaitable { + static::$conn = await SharedSetup::initVitessAsync(); + // block hole logging + Logger::setHandle(new \HH\Lib\IO\MemoryHandle()); + } - <<__Override>> - public async function beforeEachTestAsync(): Awaitable { - restore('vitess_setup'); - QueryContext::$strictSchemaMode = false; - QueryContext::$strictSQLMode = false; - } + <<__Override>> + public async function beforeEachTestAsync(): Awaitable { + restore('vitess_setup'); + QueryContext::$strictSchemaMode = false; + QueryContext::$strictSQLMode = false; + } - public async function testScatterByColumns(): Awaitable { - $conn = static::$conn as nonnull; + public async function testScatterByColumns(): Awaitable { + $conn = static::$conn as nonnull; - $unsupported_test_cases = vec[ - "select * from vt_table1 where id=2 and name='hi' group by id", - 'select * from vt_table1 group by name', - ]; - foreach ($unsupported_test_cases as $sql) { - expect(() ==> $conn->query($sql))->toThrow( - SQLFakeVitessQueryViolation::class, - 'Vitess query validation error: unsupported: in scatter query: group by column must reference column in SELECT list', - ); - } + $unsupported_test_cases = vec[ + "select * from vt_table1 where id=2 and name='hi' group by id", + 'select * from vt_table1 group by name', + ]; + foreach ($unsupported_test_cases as $sql) { + expect(() ==> $conn->query($sql))->toThrow( + SQLFakeVitessQueryViolation::class, + 'Vitess query validation error: unsupported: in scatter query: group by column must reference column in SELECT list', + ); + } - $unsupported_test_cases = vec[ - 'select * from vt_table1 where id in (1, 2) order by id', - 'select id from vt_table1 order by name', - 'select id, count(*) from vt_table1 group by id order by c1', - ]; - $supported_test_cases = vec[ - 'select 
* from vt_table1 where id=2 order by name', - "select * from vt_table1 where id=2 and name='bob' order by name,id", - ]; - foreach ($unsupported_test_cases as $sql) { - expect(() ==> $conn->query($sql))->toThrow( - SQLFakeVitessQueryViolation::class, - 'Vitess query validation error: unsupported: in scatter query: order by column must reference column in SELECT list', - ); - } - foreach ($supported_test_cases as $sql) { - expect(() ==> $conn->query($sql))->notToThrow(SQLFakeVitessQueryViolation::class); - } - } + $unsupported_test_cases = vec[ + 'select * from vt_table1 where id in (1, 2) order by id', + 'select id from vt_table1 order by name', + 'select id, count(*) from vt_table1 group by id order by c1', + ]; + $supported_test_cases = vec[ + 'select * from vt_table1 where id=2 order by name', + "select * from vt_table1 where id=2 and name='bob' order by name,id", + ]; + foreach ($unsupported_test_cases as $sql) { + expect(() ==> $conn->query($sql))->toThrow( + SQLFakeVitessQueryViolation::class, + 'Vitess query validation error: unsupported: in scatter query: order by column must reference column in SELECT list', + ); + } + foreach ($supported_test_cases as $sql) { + expect(() ==> $conn->query($sql))->notToThrow(SQLFakeVitessQueryViolation::class); + } + } - public async function testUnionsNotAllowed(): Awaitable { - $conn = static::$conn as nonnull; + public async function testUnionsNotAllowed(): Awaitable { + $conn = static::$conn as nonnull; - $test_cases = vec[ - 'select * from vt_table1 union select * from vt_table2', - 'select id from vt_table1 union all select id from vt_table2', - ]; + $test_cases = vec[ + 'select * from vt_table1 union select * from vt_table2', + 'select id from vt_table1 union all select id from vt_table2', + ]; - foreach ($test_cases as $sql) { - expect(() ==> $conn->query($sql))->toThrow( - SQLFakeVitessQueryViolation::class, - 'Vitess query validation error: unsupported: UNION cannot be executed as a single route', - ); - } - } + 
foreach ($test_cases as $sql) { + expect(() ==> $conn->query($sql))->toThrow( + SQLFakeVitessQueryViolation::class, + 'Vitess query validation error: unsupported: UNION cannot be executed as a single route', + ); + } + } } diff --git a/tests/SharedSetup.php b/tests/SharedSetup.php index e72b46a..c8b91fb 100644 --- a/tests/SharedSetup.php +++ b/tests/SharedSetup.php @@ -185,7 +185,7 @@ final class SharedSetup { 'name' => 'PRIMARY', 'type' => 'PRIMARY', 'fields' => keyset['id'], - ) + ), ], ), 'table_with_more_fields' => shape( @@ -386,7 +386,7 @@ final class SharedSetup { 'name' => 'PRIMARY', 'type' => 'PRIMARY', 'fields' => keyset['id'], - ) + ), ], ), 'association_table' => shape( diff --git a/tests/UpdateQueryTest.php b/tests/UpdateQueryTest.php index 88aa1f5..8624834 100644 --- a/tests/UpdateQueryTest.php +++ b/tests/UpdateQueryTest.php @@ -71,9 +71,7 @@ final class UpdateQueryTest extends HackTest { dict['id' => 4, 'group_id' => 13, 'name' => 'name34updated'], dict['id' => 6, 'group_id' => 13, 'name' => 'name36updated'], ]; - await $conn->query( - "UPDATE `db2`.`table3` set name=CONCAT(name, id, 'updated'), group_id = 13 WHERE group_id=6", - ); + await $conn->query("UPDATE `db2`.`table3` set name=CONCAT(name, id, 'updated'), group_id = 13 WHERE group_id=6"); $results = await $conn->query('SELECT * FROM table3 WHERE group_id=13'); expect($results->rows())->toBeSame($expected, 'with backticks'); } diff --git a/tests/UpdateQueryValidatorTest.php b/tests/UpdateQueryValidatorTest.php index 27ac140..80f8895 100644 --- a/tests/UpdateQueryValidatorTest.php +++ b/tests/UpdateQueryValidatorTest.php @@ -6,49 +6,49 @@ use type Facebook\HackTest\HackTest; final class UpdateQueryValidatorTest extends HackTest { - private static ?AsyncMysqlConnection $conn; - - <<__Override>> - public static async function beforeFirstTestAsync(): Awaitable { - static::$conn = await SharedSetup::initVitessAsync(); - // block hole logging - // ? 
copied from SelectQueryValidatorTest.php, not sure what that means. - Logger::setHandle(new \HH\Lib\IO\MemoryHandle()); - } - - <<__Override>> - public async function beforeEachTestAsync(): Awaitable { - restore('vitess_setup'); - QueryContext::$strictSchemaMode = false; - QueryContext::$strictSQLMode = false; - } - - public async function testUpdateChangesPrimaryVindex(): Awaitable { - // this is disabled for now (but if Hack knows, it will emit errors because of coeffects) - if ('always true, but opaque to Hack') { - return; - } - $conn = static::$conn as nonnull; - - $unsupported_test_cases = vec[ - 'update vt_table1 set id=1 where id=1', - 'update vt_table2 set vt_table1_id=1 where id=1', - ]; - - foreach ($unsupported_test_cases as $sql) { - expect(() ==> $conn->query($sql))->toThrow( - SQLFakeVitessQueryViolation::class, - 'Vitess query validation error: unsupported: update changes primary vindex column', - ); - } - - $supported_test_cases = vec[ - "update vt_table1 set name='foo' where id = 1", - ]; - - foreach ($supported_test_cases as $sql) { - expect(() ==> $conn->query($sql))->notToThrow(SQLFakeVitessQueryViolation::class); - } - - } + private static ?AsyncMysqlConnection $conn; + + <<__Override>> + public static async function beforeFirstTestAsync(): Awaitable { + static::$conn = await SharedSetup::initVitessAsync(); + // block hole logging + // ? copied from SelectQueryValidatorTest.php, not sure what that means. 
+ Logger::setHandle(new \HH\Lib\IO\MemoryHandle()); + } + + <<__Override>> + public async function beforeEachTestAsync(): Awaitable { + restore('vitess_setup'); + QueryContext::$strictSchemaMode = false; + QueryContext::$strictSQLMode = false; + } + + public async function testUpdateChangesPrimaryVindex(): Awaitable { + // this is disabled for now (but if Hack knows, it will emit errors because of coeffects) + if ('always true, but opaque to Hack') { + return; + } + $conn = static::$conn as nonnull; + + $unsupported_test_cases = vec[ + 'update vt_table1 set id=1 where id=1', + 'update vt_table2 set vt_table1_id=1 where id=1', + ]; + + foreach ($unsupported_test_cases as $sql) { + expect(() ==> $conn->query($sql))->toThrow( + SQLFakeVitessQueryViolation::class, + 'Vitess query validation error: unsupported: update changes primary vindex column', + ); + } + + $supported_test_cases = vec[ + "update vt_table1 set name='foo' where id = 1", + ]; + + foreach ($supported_test_cases as $sql) { + expect(() ==> $conn->query($sql))->notToThrow(SQLFakeVitessQueryViolation::class); + } + + } }