Skip to content

Commit

Permalink
Add RBAC/Admin management on findAllDatasets
Browse files Browse the repository at this point in the history
  • Loading branch information
jreynard-code committed Sep 2, 2024
1 parent d817388 commit 8abf108
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class DatasetServiceIntegrationTest : CsmRedisTestBase() {

@BeforeEach
fun beforeEach() {
every { getCurrentAccountIdentifier(any()) } returns CONNECTED_ADMIN_USER
every { getCurrentAccountIdentifier(any()) } returns TEST_USER_MAIL
every { getCurrentAuthenticatedUserName(csmPlatformProperties) } returns "test.user"
every { getCurrentAuthenticatedRoles(any()) } returns listOf()
rediSearchIndexer.createIndexFor(Connector::class.java)
Expand All @@ -153,11 +153,11 @@ class DatasetServiceIntegrationTest : CsmRedisTestBase() {

connectorSaved = connectorApiService.registerConnector(makeConnector())

organization = makeOrganizationWithRole("Organization")
organization = makeOrganizationWithRole()
organizationSaved = organizationApiService.registerOrganization(organization)
dataset = makeDatasetWithRole()
datasetSaved = datasetApiService.createDataset(organizationSaved.id!!, dataset)
dataset2 = makeDatasetWithRole()
dataset2 = makeDataset()
solution = makeSolution()
solutionSaved = solutionApiService.createSolution(organizationSaved.id!!, solution)
workspace = makeWorkspace()
Expand Down Expand Up @@ -334,6 +334,69 @@ class DatasetServiceIntegrationTest : CsmRedisTestBase() {
assertTrue { datasetCompatibilityList.isEmpty() }
}

@Test
fun `test find All Datasets as Platform Admin`() {
organizationSaved = organizationApiService.registerOrganization(organization)
val numberOfDatasets = 20
val defaultPageSize = csmPlatformProperties.twincache.dataset.defaultPageSize
val expectedSize = 15
IntRange(1, numberOfDatasets).forEach {
datasetApiService.createDataset(
organizationSaved.id!!, makeDataset("d-dataset-$it", "dataset-$it"))
}
logger.info("Change current user...")
every { getCurrentAccountIdentifier(any()) } returns CONNECTED_ADMIN_USER
every { getCurrentAuthenticatedUserName(csmPlatformProperties) } returns "test.admin"
every { getCurrentAuthenticatedRoles(any()) } returns listOf(ROLE_PLATFORM_ADMIN)

logger.info("should find all datasets and assert there are $numberOfDatasets")
var datasetList = datasetApiService.findAllDatasets(organizationSaved.id!!, null, null)
assertEquals(numberOfDatasets + 1, datasetList.size)

logger.info("should find all datasets and assert it equals defaultPageSize: $defaultPageSize")
datasetList = datasetApiService.findAllDatasets(organizationSaved.id!!, 0, null)
assertEquals(defaultPageSize, datasetList.size)

logger.info("should find all datasets and assert there are expected size: $expectedSize")
datasetList = datasetApiService.findAllDatasets(organizationSaved.id!!, 0, expectedSize)
assertEquals(expectedSize, datasetList.size)

logger.info("should find all solutions and assert it returns the second / last page")
datasetList = datasetApiService.findAllDatasets(organizationSaved.id!!, 1, expectedSize)
assertEquals(numberOfDatasets - expectedSize + 1, datasetList.size)
}

@Test
fun `test find All Datasets as Organization User`() {
organizationSaved = organizationApiService.registerOrganization(organization)
val numberOfDatasets = 20
val defaultPageSize = csmPlatformProperties.twincache.dataset.defaultPageSize
val expectedSize = 15
IntRange(1, numberOfDatasets).forEach {
datasetApiService.createDataset(
organizationSaved.id!!,
makeDatasetWithRole(
organizationId = "d-dataset-$it",
parentId = "dataset-$it",
userName = "ANOTHER_USER"))
}
logger.info("should find all datasets and assert there are $numberOfDatasets")
var datasetList = datasetApiService.findAllDatasets(organizationSaved.id!!, null, null)
assertEquals(0, datasetList.size)

logger.info("should find all datasets and assert it equals defaultPageSize: $defaultPageSize")
datasetList = datasetApiService.findAllDatasets(organizationSaved.id!!, 0, null)
assertEquals(0, datasetList.size)

logger.info("should find all datasets and assert there are expected size: $expectedSize")
datasetList = datasetApiService.findAllDatasets(organizationSaved.id!!, 0, expectedSize)
assertEquals(0, datasetList.size)

logger.info("should find all solutions and assert it returns the second / last page")
datasetList = datasetApiService.findAllDatasets(organizationSaved.id!!, 1, expectedSize)
assertEquals(0, datasetList.size)
}

@Test
fun `test find All Datasets with different pagination params`() {
organizationSaved = organizationApiService.registerOrganization(organization)
Expand Down Expand Up @@ -650,8 +713,7 @@ class DatasetServiceIntegrationTest : CsmRedisTestBase() {
@Test
fun `access control list shouldn't contain more than one time each user on creation`() {
connectorSaved = connectorApiService.registerConnector(makeConnector())
organizationSaved =
organizationApiService.registerOrganization(makeOrganizationWithRole("organization"))
organizationSaved = organizationApiService.registerOrganization(makeOrganizationWithRole())
val brokenDataset =
Dataset(
name = "dataset",
Expand All @@ -671,8 +733,7 @@ class DatasetServiceIntegrationTest : CsmRedisTestBase() {
@Test
fun `access control list shouldn't contain more than one time each user on ACL addition`() {
connectorSaved = connectorApiService.registerConnector(makeConnector())
organizationSaved =
organizationApiService.registerOrganization(makeOrganizationWithRole("organization"))
organizationSaved = organizationApiService.registerOrganization(makeOrganizationWithRole())
val workingDataset = makeDatasetWithRole("dataset", sourceType = DatasetSourceType.None)
val datasetSaved = datasetApiService.createDataset(organizationSaved.id!!, workingDataset)

Expand Down Expand Up @@ -925,7 +986,7 @@ class DatasetServiceIntegrationTest : CsmRedisTestBase() {

fun makeOrganizationWithRole(
userName: String = TEST_USER_MAIL,
role: String = ROLE_ADMIN
role: String = ROLE_EDITOR
): Organization {
return Organization(
id = UUID.randomUUID().toString(),
Expand All @@ -939,6 +1000,23 @@ class DatasetServiceIntegrationTest : CsmRedisTestBase() {
OrganizationAccessControl(id = CONNECTED_ADMIN_USER, role = ROLE_ADMIN),
OrganizationAccessControl(id = userName, role = role))))
}
fun makeDataset(
organizationId: String = organizationSaved.id!!,
parentId: String = "",
sourceType: DatasetSourceType = DatasetSourceType.File
): Dataset {
return Dataset(
id = UUID.randomUUID().toString(),
name = "My datasetRbac",
organizationId = organizationId,
parentId = parentId,
ownerId = "ownerId",
connector = DatasetConnector(connectorSaved.id!!),
twingraphId = "graph",
source = SourceInfo("location", "name", "path"),
tags = mutableListOf("dataset"),
sourceType = sourceType)
}

fun makeDatasetWithRole(
organizationId: String = organizationSaved.id!!,
Expand Down Expand Up @@ -967,7 +1045,11 @@ class DatasetServiceIntegrationTest : CsmRedisTestBase() {
DatasetAccessControl(id = userName, role = role))))
}

fun makeSolution(organizationId: String = organizationSaved.id!!): Solution {
fun makeSolution(
organizationId: String = organizationSaved.id!!,
userName: String = TEST_USER_MAIL,
role: String = ROLE_EDITOR
): Solution {
return Solution(
id = "solutionId",
key = UUID.randomUUID().toString(),
Expand All @@ -979,7 +1061,8 @@ class DatasetServiceIntegrationTest : CsmRedisTestBase() {
default = ROLE_NONE,
accessControlList =
mutableListOf(
SolutionAccessControl(id = CONNECTED_ADMIN_USER, role = ROLE_ADMIN))))
SolutionAccessControl(id = CONNECTED_ADMIN_USER, role = ROLE_ADMIN),
SolutionAccessControl(id = userName, role = role))))
}

fun makeWorkspace(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import com.cosmotech.api.events.TwingraphImportJobInfoRequest
import com.cosmotech.api.exceptions.CsmAccessForbiddenException
import com.cosmotech.api.exceptions.CsmClientException
import com.cosmotech.api.exceptions.CsmResourceNotFoundException
import com.cosmotech.api.rbac.CsmAdmin
import com.cosmotech.api.rbac.CsmRbac
import com.cosmotech.api.rbac.PERMISSION_CREATE_CHILDREN
import com.cosmotech.api.rbac.PERMISSION_DELETE
Expand Down Expand Up @@ -125,21 +126,40 @@ class DatasetServiceImpl(
private val datasetRepository: DatasetRepository,
private val unifiedJedis: UnifiedJedis,
private val csmRbac: CsmRbac,
private val csmAdmin: CsmAdmin,
private val resourceScanner: ResourceScanner
) : CsmPhoenixService(), DatasetApiServiceInterface {

override fun findAllDatasets(organizationId: String, page: Int?, size: Int?): List<Dataset> {
organizationService.getVerifiedOrganization(organizationId)

val currentUser = getCurrentAccountIdentifier(this.csmPlatformProperties)
val defaultPageSize = csmPlatformProperties.twincache.dataset.defaultPageSize
val pageable = constructPageRequest(page, size, defaultPageSize)
if (pageable != null) {
return datasetRepository.findByOrganizationId(organizationId, currentUser, pageable).toList()
}
return findAllPaginated(defaultPageSize) {
datasetRepository.findByOrganizationId(organizationId, currentUser, it).toList()
val isAdmin = csmAdmin.verifyCurrentRolesAdmin()
val result: MutableList<Dataset>

val rbacEnabled = !isAdmin && this.csmPlatformProperties.rbac.enabled

if (pageable == null) {
result =
findAllPaginated(defaultPageSize) {
if (rbacEnabled) {
val currentUser = getCurrentAccountIdentifier(this.csmPlatformProperties)
datasetRepository.findByOrganizationId(organizationId, currentUser, it).toList()
} else {
datasetRepository.findAll(it).toList()
}
}
} else {
result =
if (rbacEnabled) {
val currentUser = getCurrentAccountIdentifier(this.csmPlatformProperties)
datasetRepository.findByOrganizationId(organizationId, currentUser, pageable).toList()
} else {
datasetRepository.findAll(pageable).toList()
}
}

return result
}

override fun findDatasetById(organizationId: String, datasetId: String): Dataset {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import org.junit.jupiter.api.assertThrows
import org.junit.jupiter.api.extension.ExtendWith
import org.springframework.core.io.ByteArrayResource
import org.springframework.data.domain.Page
import org.springframework.data.domain.Pageable
import org.springframework.web.context.request.RequestContextHolder
import org.springframework.web.context.request.ServletRequestAttributes
import redis.clients.jedis.UnifiedJedis
Expand Down Expand Up @@ -100,10 +101,10 @@ class DatasetServiceImplTests {
@Test
fun `findAllDatasets should return empty list when no dataset exists`() {
every { organizationService.getVerifiedOrganization(ORGANIZATION_ID) } returns Organization()
every { datasetRepository.findByOrganizationId(ORGANIZATION_ID, any(), any()) } returns
Page.empty()
every { datasetRepository.findAll(any<Pageable>()) } returns Page.empty()

val result = datasetService.findAllDatasets(ORGANIZATION_ID, null, null)
assertEquals(emptyList<Dataset>(), result)
assertEquals(emptyList(), result)
}

@Test
Expand Down

0 comments on commit 8abf108

Please sign in to comment.