add validate callback loading state and encryption

This commit is contained in:
miloschwartz 2025-04-14 20:56:45 -04:00
parent 53be2739bb
commit aa3b527f67
No known key found for this signature in database
11 changed files with 155 additions and 22 deletions

View file

@ -92,7 +92,18 @@ const configSchema = z.object({
}) })
.optional(), .optional(),
trust_proxy: z.boolean().optional().default(true), trust_proxy: z.boolean().optional().default(true),
secret: z.string() secret: z
.string()
.optional()
.transform(getEnvOrYaml("SERVER_SECRET"))
.pipe(
z
.string()
.min(
32,
"SERVER_SECRET must be at least 32 characters long"
)
)
}), }),
traefik: z.object({ traefik: z.object({
http_entrypoint: z.string(), http_entrypoint: z.string(),

37
server/lib/crypto.ts Normal file
View file

@ -0,0 +1,37 @@
import * as crypto from "crypto";
const ALGORITHM = "aes-256-gcm";
export function encrypt(value: string, key: string): string {
const iv = crypto.randomBytes(12);
const cipher = crypto.createCipheriv(ALGORITHM, key, iv);
const encrypted = Buffer.concat([
cipher.update(value, "utf8"),
cipher.final()
]);
const authTag = cipher.getAuthTag();
return [
iv.toString("base64"),
encrypted.toString("base64"),
authTag.toString("base64")
].join(":");
}
export function decrypt(encryptedValue: string, key: string): string {
const [ivB64, encryptedB64, authTagB64] = encryptedValue.split(":");
const iv = Buffer.from(ivB64, "base64");
const encrypted = Buffer.from(encryptedB64, "base64");
const authTag = Buffer.from(authTagB64, "base64");
const decipher = crypto.createDecipheriv(ALGORITHM, key, iv);
decipher.setAuthTag(authTag);
const decrypted = Buffer.concat([
decipher.update(encrypted),
decipher.final()
]);
return decrypted.toString("utf8");
}

View file

@ -9,6 +9,8 @@ import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import { idp, idpOidcConfig, idpOrg, orgs } from "@server/db/schemas"; import { idp, idpOidcConfig, idpOrg, orgs } from "@server/db/schemas";
import { generateOidcRedirectUrl } from "@server/lib/idp/generateRedirectUrl"; import { generateOidcRedirectUrl } from "@server/lib/idp/generateRedirectUrl";
import { encrypt } from "@server/lib/crypto";
import config from "@server/lib/config";
const paramsSchema = z.object({}).strict(); const paramsSchema = z.object({}).strict();
@ -22,7 +24,8 @@ const bodySchema = z
identifierPath: z.string().nonempty(), identifierPath: z.string().nonempty(),
emailPath: z.string().optional(), emailPath: z.string().optional(),
namePath: z.string().optional(), namePath: z.string().optional(),
scopes: z.array(z.string().nonempty()) scopes: z.array(z.string().nonempty()),
autoProvision: z.boolean().optional()
}) })
.strict(); .strict();
@ -73,9 +76,15 @@ export async function createOidcIdp(
identifierPath, identifierPath,
emailPath, emailPath,
namePath, namePath,
name name,
autoProvision
} = parsedBody.data; } = parsedBody.data;
const key = config.getRawConfig().server.secret;
const encryptedSecret = encrypt(clientSecret, key);
const encryptedClientId = encrypt(clientId, key);
let idpId: number | undefined; let idpId: number | undefined;
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
const [idpRes] = await trx const [idpRes] = await trx
@ -90,11 +99,11 @@ export async function createOidcIdp(
await trx.insert(idpOidcConfig).values({ await trx.insert(idpOidcConfig).values({
idpId: idpRes.idpId, idpId: idpRes.idpId,
clientId, clientId: encryptedClientId,
clientSecret, clientSecret: encryptedSecret,
authUrl, authUrl,
tokenUrl, tokenUrl,
autoProvision: true, autoProvision,
scopes: JSON.stringify(scopes), scopes: JSON.stringify(scopes),
identifierPath, identifierPath,
emailPath, emailPath,

View file

@ -13,6 +13,7 @@ import { generateOidcRedirectUrl } from "@server/lib/idp/generateRedirectUrl";
import cookie from "cookie"; import cookie from "cookie";
import jsonwebtoken from "jsonwebtoken"; import jsonwebtoken from "jsonwebtoken";
import config from "@server/lib/config"; import config from "@server/lib/config";
import { decrypt } from "@server/lib/crypto";
const paramsSchema = z const paramsSchema = z
.object({ .object({
@ -77,10 +78,21 @@ export async function generateOidcUrl(
const parsedScopes = JSON.parse(existingIdp.idpOidcConfig.scopes); const parsedScopes = JSON.parse(existingIdp.idpOidcConfig.scopes);
const key = config.getRawConfig().server.secret;
const decryptedClientId = decrypt(
existingIdp.idpOidcConfig.clientId,
key
);
const decryptedClientSecret = decrypt(
existingIdp.idpOidcConfig.clientSecret,
key
);
const redirectUrl = generateOidcRedirectUrl(idpId); const redirectUrl = generateOidcRedirectUrl(idpId);
const client = new arctic.OAuth2Client( const client = new arctic.OAuth2Client(
existingIdp.idpOidcConfig.clientId, decryptedClientId,
existingIdp.idpOidcConfig.clientSecret, decryptedClientSecret,
redirectUrl redirectUrl
); );

View file

@ -28,6 +28,7 @@ import {
generateSessionToken, generateSessionToken,
serializeSessionCookie serializeSessionCookie
} from "@server/auth/sessions/app"; } from "@server/auth/sessions/app";
import { decrypt } from "@server/lib/crypto";
const paramsSchema = z const paramsSchema = z
.object({ .object({
@ -90,10 +91,21 @@ export async function validateOidcCallback(
); );
} }
const key = config.getRawConfig().server.secret;
const decryptedClientId = decrypt(
existingIdp.idpOidcConfig.clientId,
key
);
const decryptedClientSecret = decrypt(
existingIdp.idpOidcConfig.clientSecret,
key
);
const redirectUrl = generateOidcRedirectUrl(existingIdp.idp.idpId); const redirectUrl = generateOidcRedirectUrl(existingIdp.idp.idpId);
const client = new arctic.OAuth2Client( const client = new arctic.OAuth2Client(
existingIdp.idpOidcConfig.clientId, decryptedClientId,
existingIdp.idpOidcConfig.clientSecret, decryptedClientSecret,
redirectUrl redirectUrl
); );

View file

@ -6,6 +6,15 @@ import { ValidateOidcUrlCallbackResponse } from "@server/routers/idp";
import { AxiosResponse } from "axios"; import { AxiosResponse } from "axios";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
import {
Card,
CardHeader,
CardTitle,
CardContent,
CardDescription
} from "@/components/ui/card";
import { Alert, AlertDescription } from "@/components/ui/alert";
import { Loader2, CheckCircle2, AlertCircle } from "lucide-react";
type ValidateOidcTokenParams = { type ValidateOidcTokenParams = {
orgId: string; orgId: string;
@ -13,6 +22,7 @@ type ValidateOidcTokenParams = {
code: string | undefined; code: string | undefined;
expectedState: string | undefined; expectedState: string | undefined;
stateCookie: string | undefined; stateCookie: string | undefined;
idp: {name: string};
}; };
export default function ValidateOidcToken(props: ValidateOidcTokenParams) { export default function ValidateOidcToken(props: ValidateOidcTokenParams) {
@ -50,6 +60,9 @@ export default function ValidateOidcToken(props: ValidateOidcTokenParams) {
router.push("/"); router.push("/");
} }
setLoading(false);
await new Promise((resolve) => setTimeout(resolve, 100));
if (redirectUrl.startsWith("http")) { if (redirectUrl.startsWith("http")) {
window.location.href = res.data.data.redirectUrl; // TODO: validate this to make sure it's safe window.location.href = res.data.data.redirectUrl; // TODO: validate this to make sure it's safe
} else { } else {
@ -67,11 +80,36 @@ export default function ValidateOidcToken(props: ValidateOidcTokenParams) {
}, []); }, []);
return ( return (
<> <div className="flex items-center justify-center min-h-screen">
<h1>Validating OIDC Token...</h1> <Card className="w-full max-w-md">
{loading && <p>Loading...</p>} <CardHeader>
{!loading && <p>Token validated successfully!</p>} <CardTitle>Connecting to {props.idp.name}</CardTitle>
{error && <p>Error: {error}</p>} <CardDescription>Validating your identity</CardDescription>
</> </CardHeader>
<CardContent className="flex flex-col items-center space-y-4">
{loading && (
<div className="flex items-center space-x-2">
<Loader2 className="h-5 w-5 animate-spin" />
<span>Connecting...</span>
</div>
)}
{!loading && !error && (
<div className="flex items-center space-x-2 text-green-600">
<CheckCircle2 className="h-5 w-5" />
<span>Connected</span>
</div>
)}
{error && (
<Alert variant="destructive" className="w-full">
<AlertCircle className="h-5 w-5" />
<AlertDescription className="flex flex-col space-y-2">
<span>There was a problem connecting to {props.idp.name}. Please contact your administrator.</span>
<span className="text-xs text-muted-foreground">{error}</span>
</AlertDescription>
</Alert>
)}
</CardContent>
</Card>
</div>
); );
} }

View file

@ -1,5 +1,8 @@
import { cookies } from "next/headers"; import { cookies } from "next/headers";
import ValidateOidcToken from "./ValidateOidcToken"; import ValidateOidcToken from "./ValidateOidcToken";
import { idp } from "@server/db/schemas";
import db from "@server/db";
import { eq } from "drizzle-orm";
export default async function Page(props: { export default async function Page(props: {
params: Promise<{ orgId: string; idpId: string }>; params: Promise<{ orgId: string; idpId: string }>;
@ -14,6 +17,16 @@ export default async function Page(props: {
const allCookies = await cookies(); const allCookies = await cookies();
const stateCookie = allCookies.get("p_oidc_state")?.value; const stateCookie = allCookies.get("p_oidc_state")?.value;
// query db directly in server component because just need the name
const [idpRes] = await db
.select({ name: idp.name })
.from(idp)
.where(eq(idp.idpId, parseInt(params.idpId!)));
if (!idpRes) {
return <div>IdP not found</div>;
}
return ( return (
<> <>
<ValidateOidcToken <ValidateOidcToken
@ -22,6 +35,7 @@ export default async function Page(props: {
code={searchParams.code} code={searchParams.code}
expectedState={searchParams.state} expectedState={searchParams.state}
stateCookie={stateCookie} stateCookie={stateCookie}
idp={{ name: idpRes.name }}
/> />
</> </>
); );

View file

@ -490,7 +490,7 @@ export default function ResourceAuthPortal(props: ResourceAuthPortalProps) {
className={`${numMethods <= 1 ? "mt-0" : ""}`} className={`${numMethods <= 1 ? "mt-0" : ""}`}
> >
<LoginForm <LoginForm
redirect={`/auth/resource/${props.resource.id}`} redirect={props.redirect}
onLogin={async () => onLogin={async () =>
await handleSSOAuth() await handleSSOAuth()
} }

View file

@ -23,7 +23,7 @@
--border: hsl(20 5.9% 90%); --border: hsl(20 5.9% 90%);
--input: hsl(20 5.9% 75%); --input: hsl(20 5.9% 75%);
--ring: hsl(24.6 95% 53.1%); --ring: hsl(24.6 95% 53.1%);
--radius: 0.50rem; --radius: 0.75rem;
--chart-1: hsl(12 76% 61%); --chart-1: hsl(12 76% 61%);
--chart-2: hsl(173 58% 39%); --chart-2: hsl(173 58% 39%);
--chart-3: hsl(197 37% 24%); --chart-3: hsl(197 37% 24%);

View file

@ -24,7 +24,7 @@ import {
import { Alert, AlertDescription } from "@/components/ui/alert"; import { Alert, AlertDescription } from "@/components/ui/alert";
import { LoginResponse } from "@server/routers/auth"; import { LoginResponse } from "@server/routers/auth";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
import { AxiosResponse, AxiosResponse } from "axios"; import { AxiosResponse } from "axios";
import { formatAxiosError } from "@app/lib/api"; import { formatAxiosError } from "@app/lib/api";
import { LockIcon } from "lucide-react"; import { LockIcon } from "lucide-react";
import { createApiClient } from "@app/lib/api"; import { createApiClient } from "@app/lib/api";
@ -136,7 +136,7 @@ export default function LoginForm({ redirect, onLogin }: LoginFormProps) {
const res = await api.post<AxiosResponse<GenerateOidcUrlResponse>>( const res = await api.post<AxiosResponse<GenerateOidcUrlResponse>>(
`/auth/idp/${idpId}/oidc/generate-url`, `/auth/idp/${idpId}/oidc/generate-url`,
{ {
redirectUrl: redirect || "/" // this is the post auth redirect url redirectUrl: redirect || "/"
} }
); );

View file

@ -9,10 +9,10 @@ const patterns: PatternConfig[] = [
{ name: "Resource Auth Portal", regex: /^\/auth\/resource\/\d+$/ } { name: "Resource Auth Portal", regex: /^\/auth\/resource\/\d+$/ }
]; ];
export function cleanRedirect(input: string): string { export function cleanRedirect(input: string, fallback?: string): string {
if (!input || typeof input !== "string") { if (!input || typeof input !== "string") {
return "/"; return "/";
} }
const isAccepted = patterns.some((pattern) => pattern.regex.test(input)); const isAccepted = patterns.some((pattern) => pattern.regex.test(input));
return isAccepted ? input : "/"; return isAccepted ? input : fallback || "/";
} }