check instance lifecycle before starting the work
Co-authored-by: Zeke Hunter-Green <zeke.huntergreen@guardian.co.uk>
marjisound and zekehuntergreen committed Apr 3, 2024
1 parent 0a48c04 commit 72a46e9
Showing 3 changed files with 72 additions and 21 deletions.
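In outline, the worker now reads its own instance ID once at startup, asks the ASG for the instance's lifecycle state on every poll, and only picks up transcription work once that state is InService; scale-in protection is toggled around each poll from the worker itself rather than defaulted in the CDK stack. A rough sketch of the resulting loop, using names from the diff below (the helper signatures are inferred from the hunks and simplified, not copied from the source):

```ts
import { AutoScalingClient } from '@aws-sdk/client-auto-scaling';
import { setTimeout } from 'timers/promises';

// Signatures inferred from the diff below; the real bodies live in asg.ts and index.ts.
declare const getInstanceLifecycleState: (
	client: AutoScalingClient,
	stage: string,
	instanceId: string,
) => Promise<string | undefined>;
declare const pollTranscriptionQueue: () => Promise<void>;

const POLLING_INTERVAL_SECONDS = 30;

const workerLoop = async (
	client: AutoScalingClient,
	stage: string,
	instanceId: string,
) => {
	// Poll forever; the real loop also stops once a spot interruption is detected.
	for (;;) {
		const lifecycleState = await getInstanceLifecycleState(client, stage, instanceId);
		if (stage === 'DEV' || lifecycleState === 'InService') {
			await pollTranscriptionQueue();
		}
		await setTimeout(POLLING_INTERVAL_SECONDS * 1000);
	}
};
```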
3 changes: 0 additions & 3 deletions packages/cdk/lib/transcription-service.ts
@@ -389,9 +389,6 @@ export class TranscriptionService extends GuStack {
app: workerApp,
}),
},
// we might want to set this to true once we are actually doing transcriptions to protect the instance from
// being terminated before it has a chance to complete a transcription job.
newInstancesProtectedFromScaleIn: true,
mixedInstancesPolicy: {
launchTemplate,
instancesDistribution: {
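The CDK change above simply drops the blanket newInstancesProtectedFromScaleIn: true default from the worker ASG. Protection is instead toggled per instance at runtime (see packages/worker/src/asg.ts below); for reference, that runtime call looks roughly like the following sketch, assuming AWS SDK v3 and a placeholder region:

```ts
import {
	AutoScalingClient,
	SetInstanceProtectionCommand,
} from '@aws-sdk/client-auto-scaling';

// Toggle scale-in protection for a single busy instance, instead of
// protecting every new instance at the ASG level.
const setScaleInProtection = async (
	instanceId: string,
	stage: string,
	value: boolean,
) => {
	const client = new AutoScalingClient({ region: 'eu-west-1' }); // region is an assumption
	await client.send(
		new SetInstanceProtectionCommand({
			InstanceIds: [instanceId],
			AutoScalingGroupName: `transcription-service-workers-${stage}`,
			ProtectedFromScaleIn: value,
		}),
	);
};
```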
43 changes: 36 additions & 7 deletions packages/worker/src/asg.ts
@@ -1,20 +1,20 @@
import { SetInstanceProtectionCommand } from '@aws-sdk/client-auto-scaling';
import {
readFile,
getASGClient,
} from '@guardian/transcription-service-backend-common';
AutoScalingClient,
DescribeAutoScalingInstancesCommand,
SetInstanceProtectionCommand,
} from '@aws-sdk/client-auto-scaling';
import { logger } from '@guardian/transcription-service-backend-common';

export const updateScaleInProtection = async (
region: string,
autoScalingClient: AutoScalingClient,
stage: string,
value: boolean,
instanceId: string,
) => {
try {
if (stage !== 'DEV') {
const instanceId = readFile('/var/lib/cloud/data/instance-id');
logger.info(`instanceId retrieved from worker instance: ${instanceId}`);
const autoScalingClient = getASGClient(region);

const input = {
InstanceIds: [instanceId.trim()],
AutoScalingGroupName: `transcription-service-workers-${stage}`,
@@ -31,3 +31,32 @@ export const updateScaleInProtection = async (
throw error;
}
};

export const getInstanceLifecycleState = async (
autoScalingClient: AutoScalingClient,
stage: string,
instanceId: string,
) => {
try {
if (stage !== 'DEV') {
const input = {
InstanceIds: [instanceId.trim()],
};
const command = new DescribeAutoScalingInstancesCommand(input);
const result = await autoScalingClient.send(command);
const lifecycleState = result.AutoScalingInstances?.find(
(i) => i.InstanceId === instanceId,
)?.LifecycleState;
if (lifecycleState === undefined)
throw new Error('Could not find instance lifecycle state!');

logger.info(`lifecycleState ${lifecycleState}`);
return lifecycleState;
} else {
return undefined;
}
} catch (error) {
logger.error(`Could not retrieve ASG instance lifecycle state`, error);
throw error;
}
};
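The new getInstanceLifecycleState helper is a thin wrapper around DescribeAutoScalingInstancesCommand, so it is easy to exercise against a mocked client. A possible unit test sketch, assuming a Jest-style runner and the aws-sdk-client-mock package, neither of which is part of this commit:

```ts
import {
	AutoScalingClient,
	DescribeAutoScalingInstancesCommand,
} from '@aws-sdk/client-auto-scaling';
import { mockClient } from 'aws-sdk-client-mock';
import { getInstanceLifecycleState } from './asg';

// Intercepts send() on every AutoScalingClient instance.
const asgMock = mockClient(AutoScalingClient);

test('returns the lifecycle state reported by the ASG', async () => {
	asgMock.on(DescribeAutoScalingInstancesCommand).resolves({
		AutoScalingInstances: [
			{
				// Placeholder values for illustration only.
				InstanceId: 'i-0abc123',
				AutoScalingGroupName: 'transcription-service-workers-PROD',
				AvailabilityZone: 'eu-west-1a',
				LifecycleState: 'Pending',
				HealthStatus: 'HEALTHY',
				ProtectedFromScaleIn: false,
			},
		],
	});
	const state = await getInstanceLifecycleState(
		new AutoScalingClient({}),
		'PROD',
		'i-0abc123',
	);
	expect(state).toBe('Pending');
});
```

In a fuller suite, asgMock.reset() between tests would keep the stubbed responses from leaking across cases.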
47 changes: 36 additions & 11 deletions packages/worker/src/index.ts
@@ -11,6 +11,8 @@ import {
moveMessageToDeadLetterQueue,
logger,
publishTranscriptionOutput,
readFile,
getASGClient,
} from '@guardian/transcription-service-backend-common';
import {
OutputBucketKeys,
@@ -25,7 +27,7 @@ import {
} from './transcribe';
import path from 'path';

import { updateScaleInProtection } from './asg';
import { getInstanceLifecycleState, updateScaleInProtection } from './asg';
import { uploadAllTranscriptsToS3 } from './util';
import {
MetricsService,
Expand All @@ -35,6 +37,7 @@ import { SQSClient } from '@aws-sdk/client-sqs';
import { setTimeout } from 'timers/promises';
import { MAX_RECEIVE_COUNT } from '@guardian/transcription-service-common';
import { checkSpotInterrupt } from './spot-termination';
import { AutoScalingClient } from '@aws-sdk/client-auto-scaling';

const POLLING_INTERVAL_SECONDS = 30;

@@ -46,6 +49,7 @@ export const getCurrentReceiptHandle = () => CURRENT_MESSAGE_RECEIPT_HANDLE;

const main = async () => {
const config = await getConfig();
const instanceId = readFile('/var/lib/cloud/data/instance-id');

const metrics = new MetricsService(
config.app.stage,
@@ -58,6 +62,8 @@ const main = async () => {
config.aws.localstackEndpoint,
);

const autoScalingClient = getASGClient(config.aws.region);

if (config.app.stage !== 'DEV') {
// start job to regularly check the instance interruption (Note: deliberately not using await here so the job
// runs in the background)
@@ -68,7 +74,25 @@ const main = async () => {
// keep polling unless instance is scheduled for termination
while (!INTERRUPTION_TIME) {
pollCount += 1;
await pollTranscriptionQueue(pollCount, sqsClient, metrics, config);
const lifecycleState = await getInstanceLifecycleState(
autoScalingClient,
config.app.stage,
instanceId,
);
if (config.app.stage === 'DEV' || lifecycleState === 'InService') {
await pollTranscriptionQueue(
pollCount,
sqsClient,
autoScalingClient,
metrics,
config,
instanceId,
);
} else {
logger.warn(
`instance in state ${lifecycleState} - waiting until it goes to InService.`,
);
}
await setTimeout(POLLING_INTERVAL_SECONDS * 1000);
}
};
@@ -95,50 +119,51 @@ const publishTranscriptionOutputFailure = async (
const pollTranscriptionQueue = async (
pollCount: number,
sqsClient: SQSClient,
autoScalingClient: AutoScalingClient,
metrics: MetricsService,
config: TranscriptionConfig,
instanceId: string,
) => {
const stage = config.app.stage;
const region = config.aws.region;
const numberOfThreads = config.app.stage === 'PROD' ? 16 : 2;
const isDev = config.app.stage === 'DEV';

logger.info(
`worker polling for transcription task. Poll count = ${pollCount}`,
);

await updateScaleInProtection(region, stage, true);
await updateScaleInProtection(autoScalingClient, stage, true, instanceId);

const message = await getNextMessage(sqsClient, config.app.taskQueueUrl);

if (isSqsFailure(message)) {
logger.error(`Failed to fetch message due to ${message.errorMsg}`);
await updateScaleInProtection(region, stage, false);
await updateScaleInProtection(autoScalingClient, stage, false, instanceId);
return;
}

if (!message.message) {
logger.info('No messages available');
await updateScaleInProtection(region, stage, false);
await updateScaleInProtection(autoScalingClient, stage, false, instanceId);
return;
}

const taskMessage = message.message;
if (!taskMessage.Body) {
logger.error('message missing body');
await updateScaleInProtection(region, stage, false);
await updateScaleInProtection(autoScalingClient, stage, false, instanceId);
return;
}
if (!taskMessage.Attributes && !isDev) {
logger.error('message missing attributes');
await updateScaleInProtection(region, stage, false);
await updateScaleInProtection(autoScalingClient, stage, false, instanceId);
return;
}

const receiptHandle = taskMessage.ReceiptHandle;
if (!receiptHandle) {
logger.error('message missing receipt handle');
await updateScaleInProtection(region, stage, false);
await updateScaleInProtection(autoScalingClient, stage, false, instanceId);
return;
}
CURRENT_MESSAGE_RECEIPT_HANDLE = receiptHandle;
@@ -148,7 +173,7 @@ const pollTranscriptionQueue = async (
if (!job) {
await metrics.putMetric(FailureMetric);
logger.error('Failed to parse job message', message);
await updateScaleInProtection(region, stage, false);
await updateScaleInProtection(autoScalingClient, stage, false, instanceId);
return;
}

@@ -304,7 +329,7 @@ const pollTranscriptionQueue = async (
}
} finally {
logger.resetCommonMetadata();
await updateScaleInProtection(region, stage, false);
await updateScaleInProtection(autoScalingClient, stage, false, instanceId);
}
};
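One note on the shape of pollTranscriptionQueue after this change: protection is enabled at the top and then released on every early-return path and again in the finally block, which is why updateScaleInProtection(autoScalingClient, stage, false, instanceId) appears so many times. A possible follow-up, not part of this commit, would be a small wrapper that always releases protection in one place, roughly:

```ts
import { AutoScalingClient } from '@aws-sdk/client-auto-scaling';
import { updateScaleInProtection } from './asg';

// Hypothetical helper, not in the commit: run `work` with scale-in protection
// held, and always release it afterwards, even on early returns or errors.
const withScaleInProtection = async <T>(
	autoScalingClient: AutoScalingClient,
	stage: string,
	instanceId: string,
	work: () => Promise<T>,
): Promise<T> => {
	await updateScaleInProtection(autoScalingClient, stage, true, instanceId);
	try {
		return await work();
	} finally {
		await updateScaleInProtection(autoScalingClient, stage, false, instanceId);
	}
};
```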
